Skip to main content

cake_core/cake/
topology.rs

1use std::collections::HashMap;
2
3use crate::ModelType;
4use anyhow::Result;
5use lazy_static::lazy_static;
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8
9lazy_static! {
10    static ref LAYER_RANGE_PARSER: Regex = Regex::new(r"(?m)^(.+[^\d])(\d+)-(\d+)$").unwrap();
11}
12
13/// A single node (worker).
14#[derive(Clone, Debug, Serialize, Deserialize)]
15pub struct Node {
16    /// Address and port of the worker.
17    pub host: String,
18    /// Optional description.
19    pub description: Option<String>,
20    pub layers: Vec<String>,
21    /// Total VRAM (or system RAM) in bytes available on this node.
22    #[serde(default)]
23    pub vram_bytes: u64,
24    /// Approximate FP16 TFLOPS for this node.
25    #[serde(default)]
26    pub tflops: f64,
27    /// Backend description (e.g. "CUDA 12.4", "Apple M2 Max", "CPU").
28    #[serde(default)]
29    pub backend: String,
30    /// Hostname of the node.
31    #[serde(default)]
32    pub hostname: String,
33    /// Operating system (e.g. "linux", "macos", "windows").
34    #[serde(default)]
35    pub os: String,
36}
37
38impl Node {
39    /// Return true if this node hosts the specified layer.
40    pub fn is_text_model_layer_owner(&self, full_layer_name: &str) -> bool {
41        for prefix in self.layers.iter() {
42            if full_layer_name.starts_with(&format!("{}.", prefix)) {
43                return true;
44            }
45        }
46
47        false
48    }
49}
50
51/// The topology is a worker-name -> worker-info map.
52#[derive(Clone, Debug, Serialize, Deserialize)]
53pub struct Topology(HashMap<String, Node>);
54
55impl Topology {
56    /// Create a new empty topology.
57    pub fn new() -> Self {
58        Self(HashMap::new())
59    }
60
61    /// Load the topology from a yaml file.
62    pub fn from_path(path: &str, model_type: &ModelType) -> Result<Self> {
63        log::info!("loading topology from {}", path);
64
65        let mut topology: Self = serde_yaml::from_str(&std::fs::read_to_string(path)?)
66            .map_err(|e| anyhow!("can't read {path}: {e}"))?;
67
68        if *model_type == ModelType::TextModel {
69            // check for range expressions
70            for (_worker_name, node) in topology.iter_mut() {
71                let mut layers = vec![];
72                for layer_name in &node.layers {
73                    if let Some(caps) = LAYER_RANGE_PARSER.captures_iter(layer_name).next() {
74                        let base = caps.get(1).unwrap().as_str().to_string();
75                        let start = caps.get(2).unwrap().as_str().to_string().parse::<usize>()?;
76                        let stop = caps.get(3).unwrap().as_str().to_string().parse::<usize>()?;
77
78                        if stop < start {
79                            return Err(anyhow!(
80                                "invalid range expression {layer_name}, end must be >= start"
81                            ));
82                        }
83
84                        for n in start..=stop {
85                            layers.push(format!("{}{}", base, n));
86                        }
87                    } else {
88                        layers.push(layer_name.to_string());
89                    }
90                }
91
92                node.layers = layers;
93            }
94        }
95
96        Ok(topology)
97    }
98
99    /// Return a set of all layer names assigned to workers in this topology.
100    pub fn all_worker_layers(&self) -> std::collections::HashSet<String> {
101        let mut layers = std::collections::HashSet::new();
102        for node in self.0.values() {
103            for layer in &node.layers {
104                layers.insert(layer.clone());
105            }
106        }
107        layers
108    }
109
110    /// Return the node serving the specified layer, or None if not found.
111    pub fn get_node_for_layer(&self, layer_name: &str) -> Option<(&str, &Node)> {
112        for (node_name, node) in &self.0 {
113            for node_layer_name in &node.layers {
114                if layer_name == node_layer_name {
115                    return Some((node_name, node));
116                }
117            }
118        }
119        None
120    }
121}
122
123impl std::ops::Deref for Topology {
124    type Target = HashMap<String, Node>;
125    fn deref(&self) -> &HashMap<String, Node> {
126        &self.0
127    }
128}
129
130impl std::ops::DerefMut for Topology {
131    fn deref_mut(&mut self) -> &mut Self::Target {
132        &mut self.0
133    }
134}