// tandem_server/capability_resolver.rs

use std::collections::{BTreeMap, HashMap, HashSet};
use std::io::ErrorKind;
use std::path::PathBuf;
use std::sync::Arc;

use anyhow::anyhow;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::sync::Mutex;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct CapabilityBinding {
12    pub capability_id: String,
13    pub provider: String,
14    pub tool_name: String,
15    #[serde(default)]
16    pub tool_name_aliases: Vec<String>,
17    #[serde(default)]
18    pub request_transform: Option<Value>,
19    #[serde(default)]
20    pub response_transform: Option<Value>,
21    #[serde(default)]
22    pub metadata: Value,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct CapabilityBindingsFile {
27    pub schema_version: String,
28    #[serde(default)]
29    pub generated_at: Option<String>,
30    #[serde(default)]
31    pub bindings: Vec<CapabilityBinding>,
32}
33
34impl Default for CapabilityBindingsFile {
35    fn default() -> Self {
36        Self {
37            schema_version: "v1".to_string(),
38            generated_at: None,
39            bindings: default_spine_bindings(),
40        }
41    }
42}
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct CapabilityToolAvailability {
46    pub provider: String,
47    pub tool_name: String,
48    #[serde(default)]
49    pub schema: Value,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct CapabilityResolveInput {
54    #[serde(default)]
55    pub workflow_id: Option<String>,
56    #[serde(default)]
57    pub required_capabilities: Vec<String>,
58    #[serde(default)]
59    pub optional_capabilities: Vec<String>,
60    #[serde(default)]
61    pub provider_preference: Vec<String>,
62    #[serde(default)]
63    pub available_tools: Vec<CapabilityToolAvailability>,
64}
65
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct CapabilityReadinessInput {
68    #[serde(default)]
69    pub workflow_id: Option<String>,
70    #[serde(default)]
71    pub required_capabilities: Vec<String>,
72    #[serde(default)]
73    pub optional_capabilities: Vec<String>,
74    #[serde(default)]
75    pub provider_preference: Vec<String>,
76    #[serde(default)]
77    pub available_tools: Vec<CapabilityToolAvailability>,
78    #[serde(default)]
79    pub allow_unbound: bool,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct CapabilityResolution {
84    pub capability_id: String,
85    pub provider: String,
86    pub tool_name: String,
87    pub binding_index: usize,
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct CapabilityResolveOutput {
92    #[serde(default)]
93    pub resolved: Vec<CapabilityResolution>,
94    #[serde(default)]
95    pub missing_required: Vec<String>,
96    #[serde(default)]
97    pub missing_optional: Vec<String>,
98    #[serde(default)]
99    pub considered_bindings: usize,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
103pub struct CapabilityBlockingIssue {
104    pub code: String,
105    pub message: String,
106    #[serde(default)]
107    pub capability_ids: Vec<String>,
108    #[serde(default)]
109    pub providers: Vec<String>,
110    #[serde(default)]
111    pub tools: Vec<String>,
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
115pub struct CapabilityReadinessOutput {
116    pub workflow_id: String,
117    pub runnable: bool,
118    #[serde(default)]
119    pub resolved: Vec<CapabilityResolution>,
120    #[serde(default)]
121    pub missing_required_capabilities: Vec<String>,
122    #[serde(default)]
123    pub unbound_capabilities: Vec<String>,
124    #[serde(default)]
125    pub missing_optional_capabilities: Vec<String>,
126    #[serde(default)]
127    pub missing_servers: Vec<String>,
128    #[serde(default)]
129    pub disconnected_servers: Vec<String>,
130    #[serde(default)]
131    pub auth_pending_tools: Vec<String>,
132    #[serde(default)]
133    pub missing_secret_refs: Vec<String>,
134    pub considered_bindings: usize,
135    #[serde(default)]
136    pub recommendations: Vec<String>,
137    #[serde(default)]
138    pub blocking_issues: Vec<CapabilityBlockingIssue>,
139}
140
141#[derive(Clone)]
142pub struct CapabilityResolver {
143    bindings_path: PathBuf,
144    lock: Arc<Mutex<()>>,
145}
146
147impl CapabilityResolver {
148    pub fn new(root: PathBuf) -> Self {
149        Self {
150            bindings_path: root.join("bindings").join("capability_bindings.json"),
151            lock: Arc::new(Mutex::new(())),
152        }
153    }
154
155    pub async fn list_bindings(&self) -> anyhow::Result<CapabilityBindingsFile> {
156        self.read_bindings().await
157    }
158
159    pub async fn set_bindings(&self, file: CapabilityBindingsFile) -> anyhow::Result<()> {
160        let _guard = self.lock.lock().await;
161        validate_bindings(&file)?;
162        if let Some(parent) = self.bindings_path.parent() {
163            tokio::fs::create_dir_all(parent).await?;
164        }
165        let payload = serde_json::to_string_pretty(&file)?;
166        tokio::fs::write(&self.bindings_path, format!("{}\n", payload)).await?;
167        Ok(())
168    }
169
170    pub async fn resolve(
171        &self,
172        input: CapabilityResolveInput,
173        discovered_tools: Vec<CapabilityToolAvailability>,
174    ) -> anyhow::Result<CapabilityResolveOutput> {
175        let bindings = self.read_bindings().await?;
176        validate_bindings(&bindings)?;
177        let preference = if input.provider_preference.is_empty() {
178            vec![
179                "composio".to_string(),
180                "arcade".to_string(),
181                "mcp".to_string(),
182                "custom".to_string(),
183            ]
184        } else {
185            input.provider_preference.clone()
186        };
187        let pref_rank = preference
188            .iter()
189            .enumerate()
190            .map(|(i, provider)| (provider.to_ascii_lowercase(), i))
191            .collect::<HashMap<_, _>>();
192        let available = if input.available_tools.is_empty() {
193            discovered_tools
194        } else {
195            input.available_tools.clone()
196        };
197        let available_set = available
198            .iter()
199            .map(|row| {
200                (
201                    row.provider.to_ascii_lowercase(),
202                    canonical_tool_name(&row.tool_name),
203                )
204            })
205            .collect::<HashSet<_>>();
206
207        let mut all_capabilities = input.required_capabilities.clone();
208        for cap in &input.optional_capabilities {
209            if !all_capabilities.contains(cap) {
210                all_capabilities.push(cap.clone());
211            }
212        }
213
214        let mut resolved = Vec::new();
215        let mut missing_required = Vec::new();
216        let mut missing_optional = Vec::new();
217
218        let by_capability = group_bindings(&bindings.bindings);
219        for capability_id in all_capabilities {
220            let Some(candidates) = by_capability.get(&capability_id) else {
221                if input.required_capabilities.contains(&capability_id) {
222                    missing_required.push(capability_id);
223                } else {
224                    missing_optional.push(capability_id);
225                }
226                continue;
227            };
228            let mut chosen: Option<(usize, &CapabilityBinding)> = None;
229            for (idx, candidate) in candidates {
230                let provider = candidate.provider.to_ascii_lowercase();
231                if !binding_matches_available(candidate, &provider, &available_set) {
232                    continue;
233                }
234                if let Some((chosen_idx, chosen_binding)) = chosen {
235                    let chosen_rank = pref_rank
236                        .get(&chosen_binding.provider.to_ascii_lowercase())
237                        .copied()
238                        .unwrap_or(usize::MAX);
239                    let this_rank = pref_rank.get(&provider).copied().unwrap_or(usize::MAX);
240                    if this_rank < chosen_rank || (this_rank == chosen_rank && *idx < chosen_idx) {
241                        chosen = Some((*idx, candidate));
242                    }
243                } else {
244                    chosen = Some((*idx, candidate));
245                }
246            }
247            if let Some((binding_index, binding)) = chosen {
248                resolved.push(CapabilityResolution {
249                    capability_id: capability_id.clone(),
250                    provider: binding.provider.clone(),
251                    tool_name: binding.tool_name.clone(),
252                    binding_index,
253                });
254            } else if input.required_capabilities.contains(&capability_id) {
255                missing_required.push(capability_id);
256            } else {
257                missing_optional.push(capability_id);
258            }
259        }
260
261        resolved.sort_by(|a, b| a.capability_id.cmp(&b.capability_id));
262        missing_required.sort();
263        missing_optional.sort();
264        Ok(CapabilityResolveOutput {
265            resolved,
266            missing_required,
267            missing_optional,
268            considered_bindings: bindings.bindings.len(),
269        })
270    }
271
272    pub async fn discover_from_runtime(
273        &self,
274        mcp_tools: Vec<tandem_runtime::McpRemoteTool>,
275        local_tools: Vec<tandem_types::ToolSchema>,
276    ) -> Vec<CapabilityToolAvailability> {
277        let mut out = Vec::new();
278        for tool in mcp_tools {
279            out.push(CapabilityToolAvailability {
280                provider: provider_from_tool_name(&tool.namespaced_name),
281                tool_name: tool.namespaced_name,
282                schema: tool.input_schema,
283            });
284        }
285        for tool in local_tools {
286            out.push(CapabilityToolAvailability {
287                provider: "custom".to_string(),
288                tool_name: tool.name,
289                schema: tool.input_schema,
290            });
291        }
292        out.sort_by(|a, b| {
293            a.provider
294                .cmp(&b.provider)
295                .then_with(|| a.tool_name.cmp(&b.tool_name))
296        });
297        out.dedup_by(|a, b| {
298            a.provider.eq_ignore_ascii_case(&b.provider)
299                && a.tool_name.eq_ignore_ascii_case(&b.tool_name)
300        });
301        out
302    }
303
304    pub fn missing_capability_error(
305        workflow_id: &str,
306        missing_capabilities: &[String],
307        available_capability_bindings: &HashMap<String, Vec<String>>,
308    ) -> Value {
309        let suggestions = missing_capabilities
310            .iter()
311            .map(|cap| {
312                let bindings = available_capability_bindings
313                    .get(cap)
314                    .cloned()
315                    .unwrap_or_default();
316                serde_json::json!({
317                    "capability_id": cap,
318                    "available_bindings": bindings,
319                })
320            })
321            .collect::<Vec<_>>();
322        serde_json::json!({
323            "code": "missing_capability",
324            "workflow_id": workflow_id,
325            "missing_capabilities": missing_capabilities,
326            "suggestions": suggestions,
327        })
328    }
329
330    async fn read_bindings(&self) -> anyhow::Result<CapabilityBindingsFile> {
331        if !self.bindings_path.exists() {
332            let default = CapabilityBindingsFile::default();
333            self.set_bindings(default.clone()).await?;
334            return Ok(default);
335        }
336        let raw = tokio::fs::read_to_string(&self.bindings_path).await?;
337        let parsed = serde_json::from_str::<CapabilityBindingsFile>(&raw)?;
338        Ok(parsed)
339    }
340}
341
342fn group_bindings(
343    bindings: &[CapabilityBinding],
344) -> BTreeMap<String, Vec<(usize, &CapabilityBinding)>> {
345    let mut map = BTreeMap::<String, Vec<(usize, &CapabilityBinding)>>::new();
346    for (idx, binding) in bindings.iter().enumerate() {
347        map.entry(binding.capability_id.clone())
348            .or_default()
349            .push((idx, binding));
350    }
351    map
352}
353
354pub fn classify_missing_required(
355    bindings: &CapabilityBindingsFile,
356    missing_required: &[String],
357) -> (Vec<String>, Vec<String>) {
358    let mut missing_capabilities = Vec::new();
359    let mut unbound_capabilities = Vec::new();
360    for capability_id in missing_required {
361        if bindings
362            .bindings
363            .iter()
364            .any(|binding| binding.capability_id == *capability_id)
365        {
366            unbound_capabilities.push(capability_id.clone());
367        } else {
368            missing_capabilities.push(capability_id.clone());
369        }
370    }
371    missing_capabilities.sort();
372    missing_capabilities.dedup();
373    unbound_capabilities.sort();
374    unbound_capabilities.dedup();
375    (missing_capabilities, unbound_capabilities)
376}
377
378pub fn providers_for_capability(
379    bindings: &CapabilityBindingsFile,
380    capability_id: &str,
381) -> Vec<String> {
382    let mut providers = bindings
383        .bindings
384        .iter()
385        .filter(|binding| binding.capability_id == capability_id)
386        .map(|binding| binding.provider.to_ascii_lowercase())
387        .collect::<Vec<_>>();
388    providers.sort();
389    providers.dedup();
390    providers
391}
392
/// Derives the provider from a namespaced tool name, ignoring ASCII case.
///
/// `mcp.composio.*` maps to `composio`, `mcp.arcade.*` to `arcade`, any other `mcp.*` to `mcp`,
/// and everything else to `custom`.
fn provider_from_tool_name(tool_name: &str) -> String {
    // Ordered most-specific first: the bare `mcp.` prefix also matches the two above it.
    const PREFIX_TO_PROVIDER: [(&str, &str); 3] = [
        ("mcp.composio.", "composio"),
        ("mcp.arcade.", "arcade"),
        ("mcp.", "mcp"),
    ];

    let normalized = tool_name.to_ascii_lowercase();
    PREFIX_TO_PROVIDER
        .iter()
        .find(|(prefix, _)| normalized.starts_with(*prefix))
        .map_or("custom", |(_, provider)| *provider)
        .to_string()
}
406
407fn validate_bindings(file: &CapabilityBindingsFile) -> anyhow::Result<()> {
408    if file.schema_version.trim().is_empty() {
409        return Err(anyhow!("schema_version is required"));
410    }
411    for binding in &file.bindings {
412        if binding.capability_id.trim().is_empty() {
413            return Err(anyhow!("binding capability_id is required"));
414        }
415        if binding.provider.trim().is_empty() {
416            return Err(anyhow!("binding provider is required"));
417        }
418        if binding.tool_name.trim().is_empty() {
419            return Err(anyhow!("binding tool_name is required"));
420        }
421        for alias in &binding.tool_name_aliases {
422            if alias.trim().is_empty() {
423                return Err(anyhow!(
424                    "binding tool_name_aliases cannot contain empty values"
425                ));
426            }
427        }
428    }
429    Ok(())
430}
431
432fn default_spine_bindings() -> Vec<CapabilityBinding> {
433    vec![
434        make_binding(
435            "github.create_pull_request",
436            "composio",
437            "mcp.composio.github_create_pull_request",
438            &[
439                "mcp.composio.github.create_pull_request",
440                "mcp.composio.github_create_pr",
441            ],
442        ),
443        make_binding(
444            "github.create_pull_request",
445            "arcade",
446            "mcp.arcade.github_create_pull_request",
447            &["mcp.arcade.github.create_pull_request"],
448        ),
449        make_binding(
450            "github.create_pull_request",
451            "mcp",
452            "mcp.github.create_pull_request",
453            &["mcp.github_create_pull_request"],
454        ),
455        make_binding(
456            "github.create_issue",
457            "composio",
458            "mcp.composio.github_create_issue",
459            &["mcp.composio.github.create_issue"],
460        ),
461        make_binding(
462            "github.create_issue",
463            "arcade",
464            "mcp.arcade.github_create_issue",
465            &["mcp.arcade.github.create_issue"],
466        ),
467        make_binding(
468            "github.create_issue",
469            "mcp",
470            "mcp.github.create_issue",
471            &["mcp.github_create_issue"],
472        ),
473        make_binding(
474            "github.list_issues",
475            "composio",
476            "mcp.composio.github_list_issues",
477            &["mcp.composio.github.list_issues"],
478        ),
479        make_binding(
480            "github.get_issue",
481            "composio",
482            "mcp.composio.github_get_issue",
483            &["mcp.composio.github.get_issue"],
484        ),
485        make_binding(
486            "github.close_issue",
487            "composio",
488            "mcp.composio.github_close_issue",
489            &["mcp.composio.github.close_issue"],
490        ),
491        make_binding(
492            "github.create_branch",
493            "composio",
494            "mcp.composio.github_create_branch",
495            &["mcp.composio.github.create_branch"],
496        ),
497        make_binding(
498            "github.list_pull_requests",
499            "composio",
500            "mcp.composio.github_list_pull_requests",
501            &["mcp.composio.github.list_pull_requests"],
502        ),
503        make_binding(
504            "github.get_pull_request",
505            "composio",
506            "mcp.composio.github_get_pull_request",
507            &["mcp.composio.github.get_pull_request"],
508        ),
509        make_binding(
510            "github.comment_on_issue",
511            "composio",
512            "mcp.composio.github_create_issue_comment",
513            &["mcp.composio.github.comment_on_issue"],
514        ),
515        make_binding(
516            "github.comment_on_pull_request",
517            "composio",
518            "mcp.composio.github_create_pull_request_review_comment",
519            &["mcp.composio.github.comment_on_pull_request"],
520        ),
521        make_binding(
522            "github.list_repositories",
523            "composio",
524            "mcp.composio.github_list_repositories",
525            &["mcp.composio.github.list_repositories"],
526        ),
527        make_binding(
528            "slack.post_message",
529            "composio",
530            "mcp.composio.slack_post_message",
531            &["mcp.composio.slack.post_message"],
532        ),
533        make_binding(
534            "slack.post_message",
535            "arcade",
536            "mcp.arcade.slack_post_message",
537            &["mcp.arcade.slack.post_message"],
538        ),
539        make_binding(
540            "slack.reply_in_thread",
541            "composio",
542            "mcp.composio.slack_reply_to_thread",
543            &[
544                "mcp.composio.slack_reply_in_thread",
545                "mcp.composio.slack.reply_in_thread",
546            ],
547        ),
548        make_binding(
549            "slack.update_message",
550            "composio",
551            "mcp.composio.slack_update_message",
552            &["mcp.composio.slack.update_message"],
553        ),
554        make_binding(
555            "slack.list_channels",
556            "composio",
557            "mcp.composio.slack_list_channels",
558            &["mcp.composio.slack.list_channels"],
559        ),
560        make_binding(
561            "slack.get_channel_history",
562            "composio",
563            "mcp.composio.slack_get_channel_history",
564            &["mcp.composio.slack.get_channel_history"],
565        ),
566    ]
567}
568
569fn make_binding(
570    capability_id: &str,
571    provider: &str,
572    tool_name: &str,
573    aliases: &[&str],
574) -> CapabilityBinding {
575    CapabilityBinding {
576        capability_id: capability_id.to_string(),
577        provider: provider.to_string(),
578        tool_name: tool_name.to_string(),
579        tool_name_aliases: aliases.iter().map(|row| row.to_string()).collect(),
580        request_transform: None,
581        response_transform: None,
582        metadata: serde_json::json!({"spine": true}),
583    }
584}
585
/// Normalizes a tool name so that spelling variants compare equal.
///
/// Characters are lower-cased; every run of characters that are not ASCII letters or digits
/// collapses into a single `_`; leading and trailing `_` are removed. For example
/// `Mcp.Composio.Slack-Reply` and `mcp_composio_slack_reply` both become
/// `mcp_composio_slack_reply`.
fn canonical_tool_name(name: &str) -> String {
    let mut out = String::with_capacity(name.len());
    let mut last_was_sep = false;
    for ch in name.chars().flat_map(char::to_lowercase) {
        if ch.is_ascii_alphanumeric() {
            out.push(ch);
            last_was_sep = false;
        } else if !last_was_sep {
            // First separator of a run; the rest of the run is skipped.
            out.push('_');
            last_was_sep = true;
        }
    }
    out.trim_matches('_').to_string()
}
600
601fn binding_matches_available(
602    binding: &CapabilityBinding,
603    provider: &str,
604    available_set: &HashSet<(String, String)>,
605) -> bool {
606    let mut names = Vec::with_capacity(1 + binding.tool_name_aliases.len());
607    names.push(binding.tool_name.as_str());
608    for alias in &binding.tool_name_aliases {
609        names.push(alias.as_str());
610    }
611    names.into_iter().any(|tool_name| {
612        available_set.contains(&(provider.to_string(), canonical_tool_name(tool_name)))
613    })
614}
615
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a resolver rooted in a unique temp directory so tests never share a bindings file.
    fn fresh_resolver() -> (PathBuf, CapabilityResolver) {
        let root =
            std::env::temp_dir().join(format!("tandem-cap-resolver-{}", uuid::Uuid::new_v4()));
        let resolver = CapabilityResolver::new(root.clone());
        (root, resolver)
    }

    /// Shorthand for an available tool with a null schema.
    fn tool(provider: &str, tool_name: &str) -> CapabilityToolAvailability {
        CapabilityToolAvailability {
            provider: provider.to_string(),
            tool_name: tool_name.to_string(),
            schema: Value::Null,
        }
    }

    /// Builds a resolve request with a single required capability and no optional ones.
    fn input_for(
        workflow_id: &str,
        required: &str,
        provider_preference: &[&str],
        available_tools: Vec<CapabilityToolAvailability>,
    ) -> CapabilityResolveInput {
        CapabilityResolveInput {
            workflow_id: Some(workflow_id.to_string()),
            required_capabilities: vec![required.to_string()],
            optional_capabilities: vec![],
            provider_preference: provider_preference
                .iter()
                .map(|provider| provider.to_string())
                .collect(),
            available_tools,
        }
    }

    #[tokio::test]
    async fn resolve_prefers_composio_over_arcade_by_default() {
        let (root, resolver) = fresh_resolver();
        let result = resolver
            .resolve(
                input_for(
                    "wf-1",
                    "github.create_pull_request",
                    &[],
                    vec![
                        tool("arcade", "mcp.arcade.github_create_pull_request"),
                        tool("composio", "mcp.composio.github_create_pull_request"),
                    ],
                ),
                Vec::new(),
            )
            .await
            .expect("resolve");
        assert_eq!(result.missing_required, Vec::<String>::new());
        assert_eq!(result.resolved.len(), 1);
        assert_eq!(result.resolved[0].provider, "composio");
        let _ = std::fs::remove_dir_all(root);
    }

    #[tokio::test]
    async fn resolve_returns_missing_capability_when_unavailable() {
        let (root, resolver) = fresh_resolver();
        let result = resolver
            .resolve(
                input_for("wf-2", "github.create_pull_request", &["arcade"], vec![]),
                Vec::new(),
            )
            .await
            .expect("resolve");
        assert_eq!(
            result.missing_required,
            vec!["github.create_pull_request".to_string()]
        );
        let _ = std::fs::remove_dir_all(root);
    }

    #[tokio::test]
    async fn resolve_matches_alias_with_name_normalization() {
        let (root, resolver) = fresh_resolver();
        let result = resolver
            .resolve(
                input_for(
                    "wf-3",
                    "slack.reply_in_thread",
                    &[],
                    vec![tool("composio", "mcp.composio.slack.reply.in.thread")],
                ),
                Vec::new(),
            )
            .await
            .expect("resolve");
        assert_eq!(result.missing_required, Vec::<String>::new());
        assert_eq!(result.resolved.len(), 1);
        assert_eq!(result.resolved[0].capability_id, "slack.reply_in_thread");
        let _ = std::fs::remove_dir_all(root);
    }

    #[tokio::test]
    async fn resolve_honors_explicit_provider_preference() {
        let (root, resolver) = fresh_resolver();
        let result = resolver
            .resolve(
                input_for(
                    "wf-4",
                    "github.create_pull_request",
                    &["arcade", "composio"],
                    vec![
                        tool("composio", "mcp.composio.github_create_pull_request"),
                        tool("arcade", "mcp.arcade.github_create_pull_request"),
                    ],
                ),
                Vec::new(),
            )
            .await
            .expect("resolve");
        assert_eq!(result.missing_required, Vec::<String>::new());
        assert_eq!(result.resolved.len(), 1);
        assert_eq!(result.resolved[0].provider, "arcade");
        let _ = std::fs::remove_dir_all(root);
    }
}