Skip to main content

distri_types/
orchestrator.rs

1use std::sync::Arc;
2
3use serde_json::Value;
4
5use crate::{LlmDefinition, Message, ToolCall};
6use anyhow::Result;
7
8/// Trait for workflow runtime functions (session storage, agent calls, etc.)
9#[async_trait::async_trait]
10pub trait OrchestratorTrait: Send + Sync {
11    /// Get a session value for a specific session
12    async fn get_session_value(&self, session_id: &str, key: &str) -> Option<serde_json::Value>;
13
14    /// Set a session value for a specific session
15    async fn set_session_value(
16        &self,
17        session_id: &str,
18        key: &str,
19        value: serde_json::Value,
20    ) -> anyhow::Result<()>;
21
22    /// Call an agent via the orchestrator
23    async fn call_agent(
24        &self,
25        session_id: &str,
26        agent_name: &str,
27        task: &str,
28    ) -> anyhow::Result<String>;
29
30    async fn call_tool(
31        &self,
32        session_id: &str,
33        user_id: &str,
34        tool_call: &ToolCall,
35    ) -> anyhow::Result<serde_json::Value>;
36
37    async fn llm_execute(
38        &self,
39        llm_def: LlmDefinition,
40        llm_context: LLmContext,
41    ) -> Result<serde_json::Value, anyhow::Error>;
42}
43
44#[derive(Debug, Default, Clone)]
45pub struct LLmContext {
46    pub thread_id: Option<String>,
47    pub task_id: Option<String>,
48    pub run_id: Option<String>,
49    pub label: Option<String>,
50    pub messages: Vec<Message>,
51}
52/// Reference wrapper for orchestrator that allows late binding
53pub struct OrchestratorRef {
54    inner: std::sync::RwLock<Option<Arc<dyn OrchestratorTrait>>>,
55}
56
57impl Default for OrchestratorRef {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63impl OrchestratorRef {
64    /// Create a new orchestrator reference without an actual orchestrator
65    pub fn new() -> Self {
66        Self {
67            inner: std::sync::RwLock::new(None),
68        }
69    }
70
71    /// Set the actual orchestrator (called after orchestrator is created)
72    pub fn set_orchestrator(&self, orchestrator: Arc<dyn OrchestratorTrait>) {
73        let mut inner = self.inner.write().unwrap();
74        *inner = Some(orchestrator);
75    }
76
77    /// Get a clone of the orchestrator if available
78    fn get_orchestrator(&self) -> Option<Arc<dyn OrchestratorTrait>> {
79        let inner = self.inner.read().unwrap();
80        inner.clone()
81    }
82}
83
84#[async_trait::async_trait]
85impl OrchestratorTrait for OrchestratorRef {
86    /// Get a session value for a specific session
87    async fn get_session_value(&self, session_id: &str, key: &str) -> Option<serde_json::Value> {
88        if let Some(orchestrator) = self.get_orchestrator() {
89            orchestrator.get_session_value(session_id, key).await
90        } else {
91            // If no orchestrator is set, return None
92            None
93        }
94    }
95
96    /// Set a session value for a specific session
97    async fn set_session_value(
98        &self,
99        session_id: &str,
100        key: &str,
101        value: serde_json::Value,
102    ) -> Result<()> {
103        if let Some(orchestrator) = self.get_orchestrator() {
104            orchestrator.set_session_value(session_id, key, value).await
105        } else {
106            Err(anyhow::anyhow!("No orchestrator available"))
107        }
108    }
109
110    /// Call an agent via the orchestrator
111    async fn call_agent(&self, session_id: &str, agent_name: &str, task: &str) -> Result<String> {
112        if let Some(orchestrator) = self.get_orchestrator() {
113            orchestrator.call_agent(session_id, agent_name, task).await
114        } else {
115            Err(anyhow::anyhow!("No orchestrator available"))
116        }
117    }
118
119    async fn call_tool(
120        &self,
121        session_id: &str,
122        user_id: &str,
123        tool_call: &ToolCall,
124    ) -> Result<serde_json::Value> {
125        if let Some(orchestrator) = self.get_orchestrator() {
126            orchestrator.call_tool(session_id, user_id, tool_call).await
127        } else {
128            Err(anyhow::anyhow!("No orchestrator available"))
129        }
130    }
131
132    async fn llm_execute(
133        &self,
134        llm_def: crate::LlmDefinition,
135        llm_context: LLmContext,
136    ) -> Result<serde_json::Value, anyhow::Error> {
137        if let Some(orchestrator) = self.get_orchestrator() {
138            orchestrator.llm_execute(llm_def, llm_context).await
139        } else {
140            Err(anyhow::anyhow!("No orchestrator available"))
141        }
142    }
143}
144
145pub struct MockOrchestrator;
146
147#[async_trait::async_trait]
148impl OrchestratorTrait for MockOrchestrator {
149    async fn get_session_value(&self, session_id: &str, key: &str) -> Option<serde_json::Value> {
150        Some(Value::String(format!("{}:{}", session_id, key)))
151    }
152
153    async fn set_session_value(
154        &self,
155        _session_id: &str,
156        _key: &str,
157        _value: serde_json::Value,
158    ) -> Result<()> {
159        Ok(())
160    }
161
162    async fn call_agent(&self, session_id: &str, agent_name: &str, task: &str) -> Result<String> {
163        Ok(format!(
164            "mock response for agent: {}, task {}, session_id {}",
165            agent_name, task, session_id
166        ))
167    }
168
169    async fn call_tool(
170        &self,
171        session_id: &str,
172        user_id: &str,
173        tool_call: &ToolCall,
174    ) -> Result<serde_json::Value> {
175        Ok(Value::String(format!(
176            "mock response for tool: {},  session_id {}, user_id: {}, tool_call {}",
177            tool_call.tool_name, session_id, user_id, tool_call.tool_call_id
178        )))
179    }
180
181    async fn llm_execute(
182        &self,
183        llm_def: crate::LlmDefinition,
184        llm_context: LLmContext,
185    ) -> Result<serde_json::Value, anyhow::Error> {
186        Ok(Value::String(format!(
187            "mock response for llm_execute, llm_def {:?}, llm_context: {:?}",
188            llm_def, llm_context
189        )))
190    }
191}