scud/attractor/
stylesheet.rs1use anyhow::{bail, Result};
12use std::collections::HashMap;
13
14use super::graph::PipelineGraph;
15
16#[derive(Debug, Clone)]
18pub struct StyleRule {
19 pub selector: Selector,
20 pub properties: HashMap<String, String>,
21}
22
/// Chooses which pipeline nodes a [`StyleRule`] applies to.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Selector {
    /// `*` — matches every node.
    Universal,
    /// `.name` — matches nodes whose class list contains `name`.
    Class(String),
    /// `#name` — matches the node whose id is exactly `name`.
    Id(String),
}

impl Selector {
    /// Returns the selector's precedence: universal `0`, class `1`, id `2`.
    ///
    /// When several rules set the same property on a node, the rule whose
    /// selector has the higher specificity takes precedence.
    #[must_use]
    pub fn specificity(&self) -> u8 {
        match self {
            Selector::Universal => 0,
            Selector::Class(_) => 1,
            Selector::Id(_) => 2,
        }
    }

    /// Returns `true` if this selector targets a node with the given id and
    /// class list.
    #[must_use]
    pub fn matches(&self, node_id: &str, node_classes: &[String]) -> bool {
        match self {
            Selector::Universal => true,
            Selector::Class(class) => node_classes.contains(class),
            Selector::Id(id) => node_id == id,
        }
    }
}
53
54pub fn parse_stylesheet(input: &str) -> Result<Vec<StyleRule>> {
63 let mut rules = Vec::new();
64 let mut chars = input.chars().peekable();
65
66 loop {
67 while chars.peek().map(|c| c.is_whitespace()).unwrap_or(false) {
69 chars.next();
70 }
71
72 if chars.peek().is_none() {
73 break;
74 }
75
76 let selector = parse_selector(&mut chars)?;
78
79 while chars.peek().map(|c| c.is_whitespace()).unwrap_or(false) {
81 chars.next();
82 }
83
84 match chars.next() {
86 Some('{') => {}
87 _ => bail!("Expected '{{' after selector"),
88 }
89
90 let mut properties = HashMap::new();
92 loop {
93 while chars.peek().map(|c| c.is_whitespace()).unwrap_or(false) {
95 chars.next();
96 }
97
98 if chars.peek() == Some(&'}') {
99 chars.next();
100 break;
101 }
102
103 if chars.peek().is_none() {
104 bail!("Unterminated rule block");
105 }
106
107 let mut name = String::new();
109 while let Some(&c) = chars.peek() {
110 if c == ':' || c.is_whitespace() {
111 break;
112 }
113 name.push(c);
114 chars.next();
115 }
116
117 while chars.peek().map(|c| c.is_whitespace()).unwrap_or(false) {
119 chars.next();
120 }
121 if chars.peek() == Some(&':') {
122 chars.next();
123 }
124 while chars.peek().map(|c| c.is_whitespace()).unwrap_or(false) {
125 chars.next();
126 }
127
128 let value = if chars.peek() == Some(&'"') {
130 chars.next(); let mut v = String::new();
132 while let Some(c) = chars.next() {
133 if c == '"' {
134 break;
135 }
136 v.push(c);
137 }
138 v
139 } else {
140 let mut v = String::new();
141 while let Some(&c) = chars.peek() {
142 if c == ';' || c == '}' || c.is_whitespace() {
143 break;
144 }
145 v.push(c);
146 chars.next();
147 }
148 v
149 };
150
151 if !name.is_empty() {
152 properties.insert(name, value);
153 }
154
155 while chars.peek().map(|c| c.is_whitespace()).unwrap_or(false) {
157 chars.next();
158 }
159 if chars.peek() == Some(&';') {
160 chars.next();
161 }
162 }
163
164 rules.push(StyleRule {
165 selector,
166 properties,
167 });
168 }
169
170 Ok(rules)
171}
172
173fn parse_selector(
174 chars: &mut std::iter::Peekable<std::str::Chars>,
175) -> Result<Selector> {
176 match chars.peek() {
177 Some('*') => {
178 chars.next();
179 Ok(Selector::Universal)
180 }
181 Some('.') => {
182 chars.next();
183 let mut name = String::new();
184 while let Some(&c) = chars.peek() {
185 if c.is_alphanumeric() || c == '_' || c == '-' {
186 name.push(c);
187 chars.next();
188 } else {
189 break;
190 }
191 }
192 Ok(Selector::Class(name))
193 }
194 Some('#') => {
195 chars.next();
196 let mut name = String::new();
197 while let Some(&c) = chars.peek() {
198 if c.is_alphanumeric() || c == '_' || c == '-' {
199 name.push(c);
200 chars.next();
201 } else {
202 break;
203 }
204 }
205 Ok(Selector::Id(name))
206 }
207 Some(c) => bail!("Invalid selector start: '{}'", c),
208 None => bail!("Expected selector, got EOF"),
209 }
210}
211
212pub fn apply_stylesheet(graph: &mut PipelineGraph, rules: &[StyleRule]) {
217 let mut sorted_rules: Vec<_> = rules.iter().collect();
219 sorted_rules.sort_by_key(|r| r.selector.specificity());
220
221 for node_idx in graph.graph.node_indices() {
222 let (node_id, node_classes, has_model, has_provider, has_effort) = {
223 let node = &graph.graph[node_idx];
224 (
225 node.id.clone(),
226 node.classes.clone(),
227 node.llm_model.is_some(),
228 node.llm_provider.is_some(),
229 node.reasoning_effort != "high", )
231 };
232
233 for rule in &sorted_rules {
234 if rule.selector.matches(&node_id, &node_classes) {
235 let node = &mut graph.graph[node_idx];
236
237 if let Some(model) = rule.properties.get("model") {
239 if !has_model {
240 node.llm_model = Some(model.clone());
241 }
242 }
243 if let Some(provider) = rule.properties.get("provider") {
244 if !has_provider {
245 node.llm_provider = Some(provider.clone());
246 }
247 }
248 if let Some(effort) = rule.properties.get("reasoning_effort") {
249 if !has_effort {
250 node.reasoning_effort = effort.clone();
251 }
252 }
253 }
254 }
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_stylesheet() {
        let input = r#"
        * { model: "claude-3-haiku"; reasoning_effort: "medium" }
        .critical { model: "claude-3-opus" }
        #special_node { provider: "anthropic" }
        "#;
        let rules = parse_stylesheet(input).unwrap();
        assert_eq!(rules.len(), 3);
        assert!(matches!(rules[0].selector, Selector::Universal));
        assert!(matches!(rules[1].selector, Selector::Class(ref c) if c == "critical"));
        assert!(matches!(rules[2].selector, Selector::Id(ref id) if id == "special_node"));

        // Quoted values are stored without their quotes.
        assert_eq!(
            rules[0].properties.get("model").map(String::as_str),
            Some("claude-3-haiku")
        );
        assert_eq!(
            rules[0].properties.get("reasoning_effort").map(String::as_str),
            Some("medium")
        );
        assert_eq!(
            rules[1].properties.get("model").map(String::as_str),
            Some("claude-3-opus")
        );
        assert_eq!(
            rules[2].properties.get("provider").map(String::as_str),
            Some("anthropic")
        );
    }

    #[test]
    fn test_parse_stylesheet_compact_unquoted() {
        let rules = parse_stylesheet("*{model:haiku;provider:anthropic}").unwrap();
        assert_eq!(rules.len(), 1);
        assert_eq!(rules[0].properties.len(), 2);
        assert_eq!(
            rules[0].properties.get("model").map(String::as_str),
            Some("haiku")
        );
        assert_eq!(
            rules[0].properties.get("provider").map(String::as_str),
            Some("anthropic")
        );
    }

    #[test]
    fn test_parse_stylesheet_empty_input() {
        assert!(parse_stylesheet("").unwrap().is_empty());
        assert!(parse_stylesheet("  \n\t ").unwrap().is_empty());
    }

    #[test]
    fn test_parse_stylesheet_errors() {
        // Does not start with a selector.
        assert!(parse_stylesheet("model: x").is_err());
        // Selector not followed by '{'.
        assert!(parse_stylesheet("* model: x }").is_err());
        // Input ends inside a rule block.
        assert!(parse_stylesheet("* { model: x").is_err());
    }

    #[test]
    fn test_selector_specificity() {
        assert_eq!(Selector::Universal.specificity(), 0);
        assert_eq!(Selector::Class("x".into()).specificity(), 1);
        assert_eq!(Selector::Id("x".into()).specificity(), 2);
    }

    #[test]
    fn test_selector_matches() {
        assert!(Selector::Universal.matches("any", &[]));
        assert!(Selector::Class("fast".into()).matches("x", &["fast".into()]));
        assert!(!Selector::Class("fast".into()).matches("x", &["slow".into()]));
        assert!(Selector::Id("x".into()).matches("x", &[]));
        assert!(!Selector::Id("x".into()).matches("y", &[]));
    }

    #[test]
    fn test_apply_stylesheet() {
        use crate::attractor::dot_parser::parse_dot;
        use crate::attractor::graph::PipelineGraph;

        let input = r#"
        digraph test {
            graph [model_stylesheet="* { model: \"haiku\" }"]
            start [shape=Mdiamond]
            a [shape=box, class="fast"]
            b [shape=box, llm_model="opus"]
            finish [shape=Msquare]
            start -> a -> b -> finish
        }
        "#;
        let dot = parse_dot(input).unwrap();
        let mut graph = PipelineGraph::from_dot(&dot).unwrap();

        let rules = parse_stylesheet("* { model: \"haiku\" }").unwrap();
        apply_stylesheet(&mut graph, &rules);

        // A node without an explicit model picks up the stylesheet value.
        let a = graph.node("a").unwrap();
        assert_eq!(a.llm_model, Some("haiku".into()));

        // An explicit node attribute is never overridden by the stylesheet.
        let b = graph.node("b").unwrap();
        assert_eq!(b.llm_model, Some("opus".into()));
    }

    #[test]
    fn test_apply_stylesheet_specificity() {
        use crate::attractor::dot_parser::parse_dot;
        use crate::attractor::graph::PipelineGraph;

        let input = r#"
        digraph test {
            start [shape=Mdiamond]
            a [shape=box, class="fast"]
            b [shape=box, llm_model="opus"]
            finish [shape=Msquare]
            start -> a -> b -> finish
        }
        "#;
        let dot = parse_dot(input).unwrap();
        let mut graph = PipelineGraph::from_dot(&dot).unwrap();

        // Listed from most to least specific on purpose: the outcome must be
        // decided by specificity, not by source order.
        let rules = parse_stylesheet(
            r#"
            #a { model: "by-id" }
            .fast { model: "by-class"; provider: "by-class" }
            * { model: "by-universal"; provider: "anthropic" }
            "#,
        )
        .unwrap();
        apply_stylesheet(&mut graph, &rules);

        let a = graph.node("a").unwrap();
        assert_eq!(a.llm_model, Some("by-id".into()));
        assert_eq!(a.llm_provider, Some("by-class".into()));

        let b = graph.node("b").unwrap();
        assert_eq!(b.llm_model, Some("opus".into()));
        assert_eq!(b.llm_provider, Some("anthropic".into()));

        let start = graph.node("start").unwrap();
        assert_eq!(start.llm_model, Some("by-universal".into()));
    }
}