// cake_core/cake/topology.rs
use std::collections::{HashMap, HashSet};

use anyhow::Result;
use lazy_static::lazy_static;
use regex::Regex;
use serde::{Deserialize, Serialize};

use crate::ModelType;

9lazy_static! {
10 static ref LAYER_RANGE_PARSER: Regex = Regex::new(r"(?m)^(.+[^\d])(\d+)-(\d+)$").unwrap();
11}
12
13#[derive(Clone, Debug, Serialize, Deserialize)]
15pub struct Node {
16 pub host: String,
18 pub description: Option<String>,
20 pub layers: Vec<String>,
21 #[serde(default)]
23 pub vram_bytes: u64,
24 #[serde(default)]
26 pub tflops: f64,
27 #[serde(default)]
29 pub backend: String,
30 #[serde(default)]
32 pub hostname: String,
33 #[serde(default)]
35 pub os: String,
36}
37
38impl Node {
39 pub fn is_text_model_layer_owner(&self, full_layer_name: &str) -> bool {
41 for prefix in self.layers.iter() {
42 if full_layer_name.starts_with(&format!("{}.", prefix)) {
43 return true;
44 }
45 }
46
47 false
48 }
49}
50
51#[derive(Clone, Debug, Serialize, Deserialize)]
53pub struct Topology(HashMap<String, Node>);
54
55impl Topology {
56 pub fn new() -> Self {
58 Self(HashMap::new())
59 }
60
61 pub fn from_path(path: &str, model_type: &ModelType) -> Result<Self> {
63 log::info!("loading topology from {}", path);
64
65 let mut topology: Self = serde_yaml::from_str(&std::fs::read_to_string(path)?)
66 .map_err(|e| anyhow!("can't read {path}: {e}"))?;
67
68 if *model_type == ModelType::TextModel {
69 for (_worker_name, node) in topology.iter_mut() {
71 let mut layers = vec![];
72 for layer_name in &node.layers {
73 if let Some(caps) = LAYER_RANGE_PARSER.captures_iter(layer_name).next() {
74 let base = caps.get(1).unwrap().as_str().to_string();
75 let start = caps.get(2).unwrap().as_str().to_string().parse::<usize>()?;
76 let stop = caps.get(3).unwrap().as_str().to_string().parse::<usize>()?;
77
78 if stop < start {
79 return Err(anyhow!(
80 "invalid range expression {layer_name}, end must be >= start"
81 ));
82 }
83
84 for n in start..=stop {
85 layers.push(format!("{}{}", base, n));
86 }
87 } else {
88 layers.push(layer_name.to_string());
89 }
90 }
91
92 node.layers = layers;
93 }
94 }
95
96 Ok(topology)
97 }
98
99 pub fn all_worker_layers(&self) -> std::collections::HashSet<String> {
101 let mut layers = std::collections::HashSet::new();
102 for node in self.0.values() {
103 for layer in &node.layers {
104 layers.insert(layer.clone());
105 }
106 }
107 layers
108 }
109
110 pub fn get_node_for_layer(&self, layer_name: &str) -> Option<(&str, &Node)> {
112 for (node_name, node) in &self.0 {
113 for node_layer_name in &node.layers {
114 if layer_name == node_layer_name {
115 return Some((node_name, node));
116 }
117 }
118 }
119 None
120 }
121}
122
123impl std::ops::Deref for Topology {
124 type Target = HashMap<String, Node>;
125 fn deref(&self) -> &HashMap<String, Node> {
126 &self.0
127 }
128}
129
130impl std::ops::DerefMut for Topology {
131 fn deref_mut(&mut self) -> &mut Self::Target {
132 &mut self.0
133 }
134}