Skip to main content

distri_types/
orchestrator.rs

1use std::sync::Arc;
2
3use serde_json::Value;
4
5use crate::{LlmDefinition, Message, ToolCall};
6use anyhow::Result;
7
8/// Trait for workflow runtime functions (session storage, agent calls, etc.)
9#[async_trait::async_trait]
10pub trait OrchestratorTrait: Send + Sync {
11    /// Get a session value for a specific session
12    async fn get_session_value(&self, session_id: &str, key: &str) -> Option<serde_json::Value>;
13
14    /// Set a session value for a specific session
15    async fn set_session_value(
16        &self,
17        session_id: &str,
18        key: &str,
19        value: serde_json::Value,
20    ) -> anyhow::Result<()>;
21
22    /// Call an agent via the orchestrator
23    async fn call_agent(
24        &self,
25        session_id: &str,
26        agent_name: &str,
27        task: &str,
28    ) -> anyhow::Result<String>;
29
30    async fn call_tool(
31        &self,
32        session_id: &str,
33        user_id: &str,
34        tool_call: &ToolCall,
35    ) -> anyhow::Result<serde_json::Value>;
36
37    async fn llm_execute(
38        &self,
39        llm_def: LlmDefinition,
40        llm_context: LLmContext,
41    ) -> Result<serde_json::Value, anyhow::Error>;
42}
43
44#[derive(Debug, Default, Clone)]
45pub struct LLmContext {
46    pub thread_id: Option<String>,
47    pub task_id: Option<String>,
48    pub run_id: Option<String>,
49    pub label: Option<String>,
50    pub messages: Vec<Message>,
51}
52/// Reference wrapper for orchestrator that allows late binding
53pub struct OrchestratorRef {
54    inner: std::sync::RwLock<Option<Arc<dyn OrchestratorTrait>>>,
55}
56
57impl OrchestratorRef {
58    /// Create a new orchestrator reference without an actual orchestrator
59    pub fn new() -> Self {
60        Self {
61            inner: std::sync::RwLock::new(None),
62        }
63    }
64
65    /// Set the actual orchestrator (called after orchestrator is created)
66    pub fn set_orchestrator(&self, orchestrator: Arc<dyn OrchestratorTrait>) {
67        let mut inner = self.inner.write().unwrap();
68        *inner = Some(orchestrator);
69    }
70
71    /// Get a clone of the orchestrator if available
72    fn get_orchestrator(&self) -> Option<Arc<dyn OrchestratorTrait>> {
73        let inner = self.inner.read().unwrap();
74        inner.clone()
75    }
76}
77
78#[async_trait::async_trait]
79impl OrchestratorTrait for OrchestratorRef {
80    /// Get a session value for a specific session
81    async fn get_session_value(&self, session_id: &str, key: &str) -> Option<serde_json::Value> {
82        if let Some(orchestrator) = self.get_orchestrator() {
83            orchestrator.get_session_value(session_id, key).await
84        } else {
85            // If no orchestrator is set, return None
86            None
87        }
88    }
89
90    /// Set a session value for a specific session
91    async fn set_session_value(
92        &self,
93        session_id: &str,
94        key: &str,
95        value: serde_json::Value,
96    ) -> Result<()> {
97        if let Some(orchestrator) = self.get_orchestrator() {
98            orchestrator.set_session_value(session_id, key, value).await
99        } else {
100            Err(anyhow::anyhow!("No orchestrator available"))
101        }
102    }
103
104    /// Call an agent via the orchestrator
105    async fn call_agent(&self, session_id: &str, agent_name: &str, task: &str) -> Result<String> {
106        if let Some(orchestrator) = self.get_orchestrator() {
107            orchestrator.call_agent(session_id, agent_name, task).await
108        } else {
109            Err(anyhow::anyhow!("No orchestrator available"))
110        }
111    }
112
113    async fn call_tool(
114        &self,
115        session_id: &str,
116        user_id: &str,
117        tool_call: &ToolCall,
118    ) -> Result<serde_json::Value> {
119        if let Some(orchestrator) = self.get_orchestrator() {
120            orchestrator.call_tool(session_id, user_id, tool_call).await
121        } else {
122            Err(anyhow::anyhow!("No orchestrator available"))
123        }
124    }
125
126    async fn llm_execute(
127        &self,
128        llm_def: crate::LlmDefinition,
129        llm_context: LLmContext,
130    ) -> Result<serde_json::Value, anyhow::Error> {
131        if let Some(orchestrator) = self.get_orchestrator() {
132            orchestrator.llm_execute(llm_def, llm_context).await
133        } else {
134            Err(anyhow::anyhow!("No orchestrator available"))
135        }
136    }
137}
138
139pub struct MockOrchestrator;
140
141#[async_trait::async_trait]
142impl OrchestratorTrait for MockOrchestrator {
143    async fn get_session_value(&self, session_id: &str, key: &str) -> Option<serde_json::Value> {
144        Some(Value::String(format!("{}:{}", session_id, key)))
145    }
146
147    async fn set_session_value(
148        &self,
149        _session_id: &str,
150        _key: &str,
151        _value: serde_json::Value,
152    ) -> Result<()> {
153        Ok(())
154    }
155
156    async fn call_agent(&self, session_id: &str, agent_name: &str, task: &str) -> Result<String> {
157        Ok(format!(
158            "mock response for agent: {}, task {}, session_id {}",
159            agent_name, task, session_id
160        ))
161    }
162
163    async fn call_tool(
164        &self,
165        session_id: &str,
166        user_id: &str,
167        tool_call: &ToolCall,
168    ) -> Result<serde_json::Value> {
169        Ok(Value::String(format!(
170            "mock response for tool: {},  session_id {}, user_id: {}, tool_call {}",
171            tool_call.tool_name, session_id, user_id, tool_call.tool_call_id
172        )))
173    }
174
175    async fn llm_execute(
176        &self,
177        llm_def: crate::LlmDefinition,
178        llm_context: LLmContext,
179    ) -> Result<serde_json::Value, anyhow::Error> {
180        Ok(Value::String(format!(
181            "mock response for llm_execute, llm_def {:?}, llm_context: {:?}",
182            llm_def, llm_context
183        )))
184    }
185}