Skip to main content

gsm_runner/
flow_registry.rs

1use std::collections::{HashMap, HashSet};
2use std::fs;
3use std::path::{Path, PathBuf};
4
5use anyhow::{Context, Result, anyhow, bail};
6use greentic_pack::messaging::{MessagingAdapter, MessagingSection};
7use greentic_pack::reader::{SigningPolicy, open_pack};
8use gsm_core::{ChannelMessage, infer_platform_from_adapter_name};
9
10use crate::model::Flow;
11
/// A flow loaded from a pack, together with the routing metadata that
/// `FlowRegistry::select_flow` matches messages against.
#[derive(Debug, Clone)]
pub struct FlowDefinition {
    /// Id of the pack this flow was loaded from.
    pub pack_id: String,
    /// Version of the pack this flow was loaded from.
    // Not read anywhere in this file (only surfaced through `Debug`), hence the
    // dead_code allowance.
    #[allow(dead_code)]
    pub pack_version: String,
    /// Flow id: the flow's own `id` for YAML packs, the manifest entry id for `.gtpack`s.
    pub flow_id: String,
    /// Platform derived from the binding adapter's name via
    /// `infer_platform_from_adapter_name`; `None` when no adapter binds the flow or
    /// no platform can be derived from the name.
    pub platform: Option<String>,
    /// Route key: the name of the adapter bound to this flow, if any.
    pub route: Option<String>,
    /// The parsed flow itself.
    pub flow: Flow,
}
22
/// All loaded flows plus lookup indexes used to route incoming messages.
#[derive(Debug, Default)]
pub struct FlowRegistry {
    /// Every loaded flow; the maps below store indexes into this vector.
    flows: Vec<FlowDefinition>,
    /// Route (adapter name) -> indexes of the flows bound to that route.
    by_route: HashMap<String, Vec<usize>>,
    /// Platform name -> indexes of the flows associated with that platform.
    by_platform: HashMap<String, Vec<usize>>,
    /// Pack id -> index of that pack's default flow.
    default_by_pack: HashMap<String, usize>,
}
30
31impl FlowRegistry {
32    pub fn load_from_paths(root: &Path, paths: &[PathBuf]) -> Result<Self> {
33        let root = root
34            .canonicalize()
35            .with_context(|| format!("failed to canonicalize packs root {}", root.display()))?;
36        let mut flows: Vec<FlowDefinition> = Vec::new();
37        let mut pack_defaults: HashMap<String, String> = HashMap::new();
38
39        for path in paths {
40            let pack_path = resolve_pack_path(&root, path)?;
41            let ext = pack_path
42                .extension()
43                .and_then(|s| s.to_str())
44                .map(|s| s.to_ascii_lowercase());
45            match ext.as_deref() {
46                Some("gtpack") => {
47                    let (pack_flows, default_id) = flows_from_gtpack(&pack_path)?;
48                    if let Some(default_id) = default_id
49                        && let Some(pack_id) = pack_flows.first().map(|f| f.pack_id.clone())
50                    {
51                        pack_defaults.insert(pack_id, default_id);
52                    }
53                    flows.extend(pack_flows);
54                }
55                _ => {
56                    let (pack_flows, default_id) = flows_from_pack_yaml(&root, &pack_path)?;
57                    if let Some(default_id) = default_id
58                        && let Some(pack_id) = pack_flows.first().map(|f| f.pack_id.clone())
59                    {
60                        pack_defaults.insert(pack_id, default_id);
61                    }
62                    flows.extend(pack_flows);
63                }
64            }
65        }
66
67        flows.sort_by(|a, b| {
68            (a.pack_id.as_str(), a.flow_id.as_str(), a.route.as_deref()).cmp(&(
69                b.pack_id.as_str(),
70                b.flow_id.as_str(),
71                b.route.as_deref(),
72            ))
73        });
74
75        let mut registry = FlowRegistry {
76            flows,
77            ..Default::default()
78        };
79        for (idx, flow) in registry.flows.iter().enumerate() {
80            if let Some(route) = flow.route.as_ref() {
81                registry
82                    .by_route
83                    .entry(route.clone())
84                    .or_default()
85                    .push(idx);
86            }
87            if let Some(platform) = flow.platform.as_ref() {
88                registry
89                    .by_platform
90                    .entry(platform.clone())
91                    .or_default()
92                    .push(idx);
93            }
94        }
95
96        for (pack_id, flow_id) in pack_defaults {
97            if let Some(idx) = registry
98                .flows
99                .iter()
100                .position(|flow| flow.pack_id == pack_id && flow.flow_id == flow_id)
101            {
102                registry.default_by_pack.insert(pack_id, idx);
103            }
104        }
105
106        Ok(registry)
107    }
108
109    pub fn select_flow<'a>(&'a self, message: &ChannelMessage) -> Result<&'a FlowDefinition> {
110        if let Some(idx) = message
111            .route
112            .as_ref()
113            .and_then(|route| self.by_route.get(route))
114            .and_then(|indexes| indexes.first().copied())
115        {
116            return self
117                .flows
118                .get(idx)
119                .ok_or_else(|| anyhow!("flow index out of bounds"));
120        }
121
122        let platform = message.channel_id.as_str();
123        let mut candidates = self.by_platform.get(platform).cloned().unwrap_or_default();
124
125        if candidates.is_empty() {
126            if let Some(idx) = self.default_by_pack.values().min().copied() {
127                return self
128                    .flows
129                    .get(idx)
130                    .ok_or_else(|| anyhow!("flow index out of bounds"));
131            }
132            bail!("no flows registered for platform {platform}");
133        }
134
135        if candidates.len() == 1 {
136            return self
137                .flows
138                .get(candidates[0])
139                .ok_or_else(|| anyhow!("flow index out of bounds"));
140        }
141
142        candidates.sort();
143        for idx in &candidates {
144            if let Some(flow) = self.flows.get(*idx)
145                && self
146                    .default_by_pack
147                    .get(flow.pack_id.as_str())
148                    .is_some_and(|default_idx| default_idx == idx)
149            {
150                return self
151                    .flows
152                    .get(*idx)
153                    .ok_or_else(|| anyhow!("flow index out of bounds"));
154            }
155        }
156
157        self.flows
158            .get(candidates[0])
159            .ok_or_else(|| anyhow!("flow index out of bounds"))
160    }
161
162    #[allow(dead_code)]
163    pub fn get_flow(&self, flow_id: &str) -> Option<&FlowDefinition> {
164        self.flows.iter().find(|flow| flow.flow_id == flow_id)
165    }
166
167    pub fn is_empty(&self) -> bool {
168        self.flows.is_empty()
169    }
170}
171
/// Fields of a YAML pack spec that `flows_from_pack_yaml` reads.
#[derive(Debug, serde::Deserialize)]
struct PackSpec {
    /// Pack identifier; becomes `FlowDefinition::pack_id`.
    id: String,
    /// Pack version; becomes `FlowDefinition::pack_version`.
    version: String,
    /// Optional messaging section whose `adapters` reference the pack's flow files.
    #[serde(default)]
    messaging: Option<MessagingSection>,
}
179
180fn resolve_pack_path(root: &Path, path: &Path) -> Result<PathBuf> {
181    if path.is_absolute() {
182        let canonical = path
183            .canonicalize()
184            .with_context(|| format!("failed to canonicalize {}", path.display()))?;
185        Ok(canonical)
186    } else {
187        gsm_core::path_safety::normalize_under_root(root, path)
188    }
189}
190
191fn flows_from_pack_yaml(root: &Path, path: &Path) -> Result<(Vec<FlowDefinition>, Option<String>)> {
192    let raw = fs::read_to_string(path)
193        .with_context(|| format!("failed to read pack file {}", path.display()))?;
194    let spec: PackSpec = serde_yaml_bw::from_str(&raw)
195        .with_context(|| format!("{} is not a valid pack spec", path.display()))?;
196
197    let mut flows = Vec::new();
198    let mut default_flow: Option<String> = spec
199        .messaging
200        .as_ref()
201        .and_then(|m| m.adapters.as_ref())
202        .and_then(|adapters| {
203            adapters
204                .iter()
205                .find_map(|adapter| adapter.default_flow.as_ref().map(|_| adapter))
206        })
207        .and_then(|adapter| adapter.default_flow.clone());
208    let pack_dir = path
209        .parent()
210        .ok_or_else(|| anyhow!("pack path missing parent: {}", path.display()))?;
211    let mut adapters = spec.messaging.and_then(|m| m.adapters).unwrap_or_default();
212    adapters.sort_by(|a, b| a.name.cmp(&b.name));
213
214    let mut flow_cache: HashMap<PathBuf, Flow> = HashMap::new();
215
216    for adapter in adapters {
217        let flow_path = adapter
218            .custom_flow
219            .as_ref()
220            .or(adapter.default_flow.as_ref());
221        let Some(flow_path) = flow_path else {
222            continue;
223        };
224        let resolved = resolve_flow_path(root, pack_dir, Path::new(flow_path))?;
225        let flow = if let Some(existing) = flow_cache.get(&resolved) {
226            existing.clone()
227        } else {
228            let loaded = Flow::load_from_file(resolved.to_str().unwrap())?;
229            flow_cache.insert(resolved.clone(), loaded.clone());
230            loaded
231        };
232        if default_flow.as_deref() == adapter.default_flow.as_deref() {
233            default_flow = Some(flow.id.clone());
234        }
235        flows.push(flow_definition_from_adapter(
236            spec.id.clone(),
237            spec.version.clone(),
238            &adapter,
239            flow.id.clone(),
240            flow,
241        ));
242    }
243
244    if default_flow.is_none() {
245        default_flow = flows.first().map(|flow| flow.flow_id.clone());
246    }
247
248    Ok((flows, default_flow))
249}
250
251fn flows_from_gtpack(path: &Path) -> Result<(Vec<FlowDefinition>, Option<String>)> {
252    let pack = open_pack(path, SigningPolicy::DevOk)
253        .map_err(|err| anyhow!(err.message))
254        .with_context(|| format!("failed to open {}", path.display()))?;
255
256    let pack_id = pack.manifest.meta.pack_id.clone();
257    let pack_version = pack.manifest.meta.version.to_string();
258    let mut flow_cache: HashMap<String, Flow> = HashMap::new();
259    let mut flows = Vec::new();
260    let mut registered: HashSet<String> = HashSet::new();
261
262    for entry in &pack.manifest.flows {
263        let yaml = pack
264            .files
265            .get(&entry.file_yaml)
266            .ok_or_else(|| anyhow!("missing flow file {}", entry.file_yaml))?;
267        let contents = String::from_utf8(yaml.clone())
268            .with_context(|| format!("flow file {} is not UTF-8", entry.file_yaml))?;
269        let flow = Flow::load_from_str(&entry.file_yaml, &contents)?;
270        flow_cache.insert(entry.id.clone(), flow);
271    }
272
273    if let Some(messaging) = pack.manifest.meta.messaging.as_ref()
274        && let Some(adapters) = messaging.adapters.as_ref()
275    {
276        for adapter in adapters {
277            if let Some(flow_id) = resolve_flow_id_for_adapter(adapter, &pack.manifest.flows)
278                && let Some(flow) = flow_cache.get(&flow_id).cloned()
279            {
280                flows.push(flow_definition_from_adapter(
281                    pack_id.clone(),
282                    pack_version.clone(),
283                    adapter,
284                    flow_id.clone(),
285                    flow,
286                ));
287                registered.insert(flow_id);
288            }
289        }
290    }
291
292    for (flow_id, flow) in flow_cache {
293        if registered.contains(&flow_id) {
294            continue;
295        }
296        flows.push(FlowDefinition {
297            pack_id: pack_id.clone(),
298            pack_version: pack_version.clone(),
299            flow_id,
300            platform: None,
301            route: None,
302            flow,
303        });
304    }
305
306    let default_flow = pack.manifest.meta.entry_flows.first().cloned();
307
308    Ok((flows, default_flow))
309}
310
311fn resolve_flow_id_for_adapter(
312    adapter: &MessagingAdapter,
313    flows: &[greentic_pack::builder::FlowEntry],
314) -> Option<String> {
315    let flow_path = adapter
316        .custom_flow
317        .as_ref()
318        .or(adapter.default_flow.as_ref())?;
319    flows
320        .iter()
321        .find(|entry| entry.file_yaml == *flow_path)
322        .map(|entry| entry.id.clone())
323}
324
325fn resolve_flow_path(root: &Path, pack_dir: &Path, path: &Path) -> Result<PathBuf> {
326    if path.is_absolute() {
327        bail!("absolute flow paths are not allowed: {}", path.display());
328    }
329    let joined = pack_dir.join(path);
330    let canon = joined
331        .canonicalize()
332        .with_context(|| format!("failed to canonicalize {}", joined.display()))?;
333    if !canon.starts_with(root) {
334        bail!(
335            "flow path escapes packs root ({}): {}",
336            root.display(),
337            canon.display()
338        );
339    }
340    Ok(canon)
341}
342
343fn flow_definition_from_adapter(
344    pack_id: String,
345    pack_version: String,
346    adapter: &MessagingAdapter,
347    flow_id: String,
348    flow: Flow,
349) -> FlowDefinition {
350    let platform = infer_platform_from_adapter_name(&adapter.name)
351        .map(|platform| platform.as_str().to_string());
352    FlowDefinition {
353        pack_id,
354        pack_version,
355        flow_id,
356        platform,
357        route: Some(adapter.name.clone()),
358        flow,
359    }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use gsm_core::{ChannelMessage, make_tenant_ctx};
    use std::fs;

    /// Scratch directory under the system temp dir. It is removed on drop, so tests
    /// do not leave `flow-registry-*` directories behind, even when an assertion panics.
    struct TempDir(PathBuf);

    impl TempDir {
        fn new() -> Self {
            let dir =
                std::env::temp_dir().join(format!("flow-registry-{}", uuid::Uuid::new_v4()));
            fs::create_dir_all(&dir).unwrap();
            TempDir(dir)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            // Best-effort cleanup; a failure here must not mask the test outcome.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    fn write_flow(dir: &Path, name: &str, id: &str) -> PathBuf {
        let path = dir.join(name);
        let contents = format!(
            r#"id: {id}
type: messaging
in: start
nodes:
  start:
    routes: []
"#
        );
        fs::write(&path, contents).unwrap();
        path
    }

    /// Writes `default.ygtc`, `alt.ygtc` and a `pack.yaml` that binds them to the
    /// `slack-main` and `slack-alt` adapters respectively.
    fn write_slack_pack(dir: &Path) {
        let flow_default = write_flow(dir, "default.ygtc", "flow-default");
        let flow_alt = write_flow(dir, "alt.ygtc", "flow-alt");
        let pack = format!(
            r#"id: test-pack
version: 1.0.0
messaging:
  adapters:
    - name: slack-main
      kind: ingress-egress
      component: slack-adapter@1.0.0
      default_flow: {}
    - name: slack-alt
      kind: ingress-egress
      component: slack-adapter@1.0.0
      default_flow: {}
"#,
            flow_default.file_name().unwrap().to_string_lossy(),
            flow_alt.file_name().unwrap().to_string_lossy()
        );
        fs::write(dir.join("pack.yaml"), pack).unwrap();
    }

    fn message(channel_id: &str, route: Option<&str>) -> ChannelMessage {
        ChannelMessage {
            tenant: make_tenant_ctx("acme".into(), Some("team".into()), None),
            channel_id: channel_id.into(),
            session_id: "chat".into(),
            route: route.map(Into::into),
            payload: serde_json::json!({
                "chat_id": "chat",
                "msg_id": "m1",
                "timestamp": "2025-01-01T00:00:00Z"
            }),
        }
    }

    #[test]
    fn selects_flow_by_route_then_platform_then_default() {
        let dir = TempDir::new();
        write_slack_pack(dir.path());
        let registry =
            FlowRegistry::load_from_paths(dir.path(), &[PathBuf::from("pack.yaml")]).unwrap();

        let selected = registry
            .select_flow(&message("slack", Some("slack-alt")))
            .unwrap();
        assert_eq!(selected.flow_id, "flow-alt");

        let selected = registry.select_flow(&message("slack", None)).unwrap();
        assert_eq!(selected.flow_id, "flow-default");
    }

    #[test]
    fn unknown_route_and_platform_fall_back_to_pack_default() {
        let dir = TempDir::new();
        write_slack_pack(dir.path());
        let registry =
            FlowRegistry::load_from_paths(dir.path(), &[PathBuf::from("pack.yaml")]).unwrap();

        let selected = registry
            .select_flow(&message("unknown-platform", Some("no-such-route")))
            .unwrap();
        assert_eq!(selected.flow_id, "flow-default");
    }

    #[test]
    fn rejects_flow_paths_outside_packs_root() {
        let dir = TempDir::new();
        let root = dir.path().join("packs");
        fs::create_dir_all(&root).unwrap();
        write_flow(dir.path(), "outside.ygtc", "flow-outside");
        let pack = r#"id: test-pack
version: 1.0.0
messaging:
  adapters:
    - name: slack-main
      kind: ingress-egress
      component: slack-adapter@1.0.0
      default_flow: ../outside.ygtc
"#;
        fs::write(root.join("pack.yaml"), pack).unwrap();

        let err =
            FlowRegistry::load_from_paths(&root, &[PathBuf::from("pack.yaml")]).unwrap_err();
        assert!(
            format!("{err:#}").contains("escapes packs root"),
            "unexpected error: {err:#}"
        );
    }
}