1use std::sync::Arc;
2
3use serde_json::Value;
4
5use crate::{LlmDefinition, Message, ToolCall};
6use anyhow::Result;
7
8#[async_trait::async_trait]
10pub trait OrchestratorTrait: Send + Sync {
11 async fn get_session_value(&self, session_id: &str, key: &str) -> Option<serde_json::Value>;
13
14 async fn set_session_value(
16 &self,
17 session_id: &str,
18 key: &str,
19 value: serde_json::Value,
20 ) -> anyhow::Result<()>;
21
22 async fn call_agent(
24 &self,
25 session_id: &str,
26 agent_name: &str,
27 task: &str,
28 ) -> anyhow::Result<String>;
29
30 async fn call_tool(
31 &self,
32 session_id: &str,
33 user_id: &str,
34 tool_call: &ToolCall,
35 ) -> anyhow::Result<serde_json::Value>;
36
37 async fn llm_execute(
38 &self,
39 llm_def: LlmDefinition,
40 llm_context: LLmContext,
41 ) -> Result<serde_json::Value, anyhow::Error>;
42}
43
44#[derive(Debug, Default, Clone)]
45pub struct LLmContext {
46 pub thread_id: Option<String>,
47 pub task_id: Option<String>,
48 pub run_id: Option<String>,
49 pub label: Option<String>,
50 pub messages: Vec<Message>,
51}
52pub struct OrchestratorRef {
54 inner: std::sync::RwLock<Option<Arc<dyn OrchestratorTrait>>>,
55}
56
57impl OrchestratorRef {
58 pub fn new() -> Self {
60 Self {
61 inner: std::sync::RwLock::new(None),
62 }
63 }
64
65 pub fn set_orchestrator(&self, orchestrator: Arc<dyn OrchestratorTrait>) {
67 let mut inner = self.inner.write().unwrap();
68 *inner = Some(orchestrator);
69 }
70
71 fn get_orchestrator(&self) -> Option<Arc<dyn OrchestratorTrait>> {
73 let inner = self.inner.read().unwrap();
74 inner.clone()
75 }
76}
77
78#[async_trait::async_trait]
79impl OrchestratorTrait for OrchestratorRef {
80 async fn get_session_value(&self, session_id: &str, key: &str) -> Option<serde_json::Value> {
82 if let Some(orchestrator) = self.get_orchestrator() {
83 orchestrator.get_session_value(session_id, key).await
84 } else {
85 None
87 }
88 }
89
90 async fn set_session_value(
92 &self,
93 session_id: &str,
94 key: &str,
95 value: serde_json::Value,
96 ) -> Result<()> {
97 if let Some(orchestrator) = self.get_orchestrator() {
98 orchestrator.set_session_value(session_id, key, value).await
99 } else {
100 Err(anyhow::anyhow!("No orchestrator available"))
101 }
102 }
103
104 async fn call_agent(&self, session_id: &str, agent_name: &str, task: &str) -> Result<String> {
106 if let Some(orchestrator) = self.get_orchestrator() {
107 orchestrator.call_agent(session_id, agent_name, task).await
108 } else {
109 Err(anyhow::anyhow!("No orchestrator available"))
110 }
111 }
112
113 async fn call_tool(
114 &self,
115 session_id: &str,
116 user_id: &str,
117 tool_call: &ToolCall,
118 ) -> Result<serde_json::Value> {
119 if let Some(orchestrator) = self.get_orchestrator() {
120 orchestrator.call_tool(session_id, user_id, tool_call).await
121 } else {
122 Err(anyhow::anyhow!("No orchestrator available"))
123 }
124 }
125
126 async fn llm_execute(
127 &self,
128 llm_def: crate::LlmDefinition,
129 llm_context: LLmContext,
130 ) -> Result<serde_json::Value, anyhow::Error> {
131 if let Some(orchestrator) = self.get_orchestrator() {
132 orchestrator.llm_execute(llm_def, llm_context).await
133 } else {
134 Err(anyhow::anyhow!("No orchestrator available"))
135 }
136 }
137}
138
/// Test double for [`OrchestratorTrait`] that performs no real work.
///
/// Every method succeeds. The value-returning ones echo their inputs back in a
/// formatted string so tests can check what was forwarded. The derives let it
/// be printed in assertion failures and constructed or copied generically.
#[derive(Debug, Default, Clone, Copy)]
pub struct MockOrchestrator;
140
141#[async_trait::async_trait]
142impl OrchestratorTrait for MockOrchestrator {
143 async fn get_session_value(&self, session_id: &str, key: &str) -> Option<serde_json::Value> {
144 Some(Value::String(format!("{}:{}", session_id, key)))
145 }
146
147 async fn set_session_value(
148 &self,
149 _session_id: &str,
150 _key: &str,
151 _value: serde_json::Value,
152 ) -> Result<()> {
153 Ok(())
154 }
155
156 async fn call_agent(&self, session_id: &str, agent_name: &str, task: &str) -> Result<String> {
157 Ok(format!(
158 "mock response for agent: {}, task {}, session_id {}",
159 agent_name, task, session_id
160 ))
161 }
162
163 async fn call_tool(
164 &self,
165 session_id: &str,
166 user_id: &str,
167 tool_call: &ToolCall,
168 ) -> Result<serde_json::Value> {
169 Ok(Value::String(format!(
170 "mock response for tool: {}, session_id {}, user_id: {}, tool_call {}",
171 tool_call.tool_name, session_id, user_id, tool_call.tool_call_id
172 )))
173 }
174
175 async fn llm_execute(
176 &self,
177 llm_def: crate::LlmDefinition,
178 llm_context: LLmContext,
179 ) -> Result<serde_json::Value, anyhow::Error> {
180 Ok(Value::String(format!(
181 "mock response for llm_execute, llm_def {:?}, llm_context: {:?}",
182 llm_def, llm_context
183 )))
184 }
185}