1use std::sync::Arc;
2
3use serde_json::Value;
4
5use crate::{LlmDefinition, Message, ToolCall};
6use anyhow::Result;
7
8#[async_trait::async_trait]
10pub trait OrchestratorTrait: Send + Sync {
11 async fn get_session_value(&self, session_id: &str, key: &str) -> Option<serde_json::Value>;
13
14 async fn set_session_value(
16 &self,
17 session_id: &str,
18 key: &str,
19 value: serde_json::Value,
20 ) -> anyhow::Result<()>;
21
22 async fn call_agent(
24 &self,
25 session_id: &str,
26 agent_name: &str,
27 task: &str,
28 ) -> anyhow::Result<String>;
29
30 async fn call_tool(
31 &self,
32 session_id: &str,
33 user_id: &str,
34 tool_call: &ToolCall,
35 ) -> anyhow::Result<serde_json::Value>;
36
37 async fn llm_execute(
38 &self,
39 llm_def: LlmDefinition,
40 llm_context: LLmContext,
41 ) -> Result<serde_json::Value, anyhow::Error>;
42}
43
44#[derive(Debug, Default, Clone)]
45pub struct LLmContext {
46 pub thread_id: Option<String>,
47 pub task_id: Option<String>,
48 pub run_id: Option<String>,
49 pub label: Option<String>,
50 pub messages: Vec<Message>,
51}
52pub struct OrchestratorRef {
54 inner: std::sync::RwLock<Option<Arc<dyn OrchestratorTrait>>>,
55}
56
57impl Default for OrchestratorRef {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63impl OrchestratorRef {
64 pub fn new() -> Self {
66 Self {
67 inner: std::sync::RwLock::new(None),
68 }
69 }
70
71 pub fn set_orchestrator(&self, orchestrator: Arc<dyn OrchestratorTrait>) {
73 let mut inner = self.inner.write().unwrap();
74 *inner = Some(orchestrator);
75 }
76
77 fn get_orchestrator(&self) -> Option<Arc<dyn OrchestratorTrait>> {
79 let inner = self.inner.read().unwrap();
80 inner.clone()
81 }
82}
83
84#[async_trait::async_trait]
85impl OrchestratorTrait for OrchestratorRef {
86 async fn get_session_value(&self, session_id: &str, key: &str) -> Option<serde_json::Value> {
88 if let Some(orchestrator) = self.get_orchestrator() {
89 orchestrator.get_session_value(session_id, key).await
90 } else {
91 None
93 }
94 }
95
96 async fn set_session_value(
98 &self,
99 session_id: &str,
100 key: &str,
101 value: serde_json::Value,
102 ) -> Result<()> {
103 if let Some(orchestrator) = self.get_orchestrator() {
104 orchestrator.set_session_value(session_id, key, value).await
105 } else {
106 Err(anyhow::anyhow!("No orchestrator available"))
107 }
108 }
109
110 async fn call_agent(&self, session_id: &str, agent_name: &str, task: &str) -> Result<String> {
112 if let Some(orchestrator) = self.get_orchestrator() {
113 orchestrator.call_agent(session_id, agent_name, task).await
114 } else {
115 Err(anyhow::anyhow!("No orchestrator available"))
116 }
117 }
118
119 async fn call_tool(
120 &self,
121 session_id: &str,
122 user_id: &str,
123 tool_call: &ToolCall,
124 ) -> Result<serde_json::Value> {
125 if let Some(orchestrator) = self.get_orchestrator() {
126 orchestrator.call_tool(session_id, user_id, tool_call).await
127 } else {
128 Err(anyhow::anyhow!("No orchestrator available"))
129 }
130 }
131
132 async fn llm_execute(
133 &self,
134 llm_def: crate::LlmDefinition,
135 llm_context: LLmContext,
136 ) -> Result<serde_json::Value, anyhow::Error> {
137 if let Some(orchestrator) = self.get_orchestrator() {
138 orchestrator.llm_execute(llm_def, llm_context).await
139 } else {
140 Err(anyhow::anyhow!("No orchestrator available"))
141 }
142 }
143}
144
145pub struct MockOrchestrator;
146
147#[async_trait::async_trait]
148impl OrchestratorTrait for MockOrchestrator {
149 async fn get_session_value(&self, session_id: &str, key: &str) -> Option<serde_json::Value> {
150 Some(Value::String(format!("{}:{}", session_id, key)))
151 }
152
153 async fn set_session_value(
154 &self,
155 _session_id: &str,
156 _key: &str,
157 _value: serde_json::Value,
158 ) -> Result<()> {
159 Ok(())
160 }
161
162 async fn call_agent(&self, session_id: &str, agent_name: &str, task: &str) -> Result<String> {
163 Ok(format!(
164 "mock response for agent: {}, task {}, session_id {}",
165 agent_name, task, session_id
166 ))
167 }
168
169 async fn call_tool(
170 &self,
171 session_id: &str,
172 user_id: &str,
173 tool_call: &ToolCall,
174 ) -> Result<serde_json::Value> {
175 Ok(Value::String(format!(
176 "mock response for tool: {}, session_id {}, user_id: {}, tool_call {}",
177 tool_call.tool_name, session_id, user_id, tool_call.tool_call_id
178 )))
179 }
180
181 async fn llm_execute(
182 &self,
183 llm_def: crate::LlmDefinition,
184 llm_context: LLmContext,
185 ) -> Result<serde_json::Value, anyhow::Error> {
186 Ok(Value::String(format!(
187 "mock response for llm_execute, llm_def {:?}, llm_context: {:?}",
188 llm_def, llm_context
189 )))
190 }
191}