claude_agent_sdk/orchestration/
context.rs1use crate::orchestration::agent::AgentOutput;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::time::Duration;
10use tokio::sync::RwLock;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ExecutionConfig {
15 pub timeout: Duration,
17
18 pub max_retries: usize,
20
21 pub parallel_limit: usize,
23
24 pub enable_logging: bool,
26
27 pub enable_tracing: bool,
29}
30
31impl Default for ExecutionConfig {
32 fn default() -> Self {
33 Self {
34 timeout: Duration::from_secs(300), max_retries: 3,
36 parallel_limit: 10,
37 enable_logging: true,
38 enable_tracing: true,
39 }
40 }
41}
42
43impl ExecutionConfig {
44 pub fn new() -> Self {
46 Self::default()
47 }
48
49 pub fn with_timeout(mut self, timeout: Duration) -> Self {
51 self.timeout = timeout;
52 self
53 }
54
55 pub fn with_max_retries(mut self, max_retries: usize) -> Self {
57 self.max_retries = max_retries;
58 self
59 }
60
61 pub fn with_parallel_limit(mut self, parallel_limit: usize) -> Self {
63 self.parallel_limit = parallel_limit;
64 self
65 }
66
67 pub fn with_logging(mut self, enable: bool) -> Self {
69 self.enable_logging = enable;
70 self
71 }
72
73 pub fn with_tracing(mut self, enable: bool) -> Self {
75 self.enable_tracing = enable;
76 self
77 }
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct ExecutionTrace {
83 pub start_time: chrono::DateTime<chrono::Utc>,
85
86 pub end_time: Option<chrono::DateTime<chrono::Utc>>,
88
89 pub agent_executions: Vec<AgentExecution>,
91
92 pub duration_ms: Option<u64>,
94}
95
96impl ExecutionTrace {
97 pub fn new() -> Self {
99 Self {
100 start_time: chrono::Utc::now(),
101 end_time: None,
102 agent_executions: Vec::new(),
103 duration_ms: None,
104 }
105 }
106
107 pub fn add_execution(&mut self, execution: AgentExecution) {
109 self.agent_executions.push(execution);
110 }
111
112 pub fn complete(&mut self) {
114 self.end_time = Some(chrono::Utc::now());
115 self.duration_ms = Some(
116 self.end_time
117 .unwrap()
118 .signed_duration_since(self.start_time)
119 .num_milliseconds() as u64,
120 );
121 }
122
123 pub fn duration(&self) -> Option<chrono::Duration> {
125 self.duration_ms
126 .map(|ms| chrono::Duration::milliseconds(ms as i64))
127 }
128}
129
130#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct AgentExecution {
133 pub agent_name: String,
135
136 pub start_time: chrono::DateTime<chrono::Utc>,
138
139 pub end_time: Option<chrono::DateTime<chrono::Utc>>,
141
142 pub input: crate::orchestration::agent::AgentInput,
144
145 pub output: Option<AgentOutput>,
147
148 pub success: bool,
150
151 pub error: Option<String>,
153
154 pub duration_ms: Option<u64>,
156}
157
158impl AgentExecution {
159 pub fn new(
161 agent_name: impl Into<String>,
162 input: crate::orchestration::agent::AgentInput,
163 ) -> Self {
164 Self {
165 agent_name: agent_name.into(),
166 start_time: chrono::Utc::now(),
167 end_time: None,
168 input,
169 output: None,
170 success: false,
171 error: None,
172 duration_ms: None,
173 }
174 }
175
176 pub fn succeed(&mut self, output: AgentOutput) {
178 self.success = true;
179 self.output = Some(output);
180 self.end_time = Some(chrono::Utc::now());
181 self.duration_ms = Some(
182 self.end_time
183 .unwrap()
184 .signed_duration_since(self.start_time)
185 .num_milliseconds() as u64,
186 );
187 }
188
189 pub fn fail(&mut self, error: impl Into<String>) {
191 self.success = false;
192 self.error = Some(error.into());
193 self.end_time = Some(chrono::Utc::now());
194 self.duration_ms = Some(
195 self.end_time
196 .unwrap()
197 .signed_duration_since(self.start_time)
198 .num_milliseconds() as u64,
199 );
200 }
201}
202
203pub struct ExecutionContext {
205 config: ExecutionConfig,
207
208 state: RwLock<HashMap<String, serde_json::Value>>,
210
211 trace: RwLock<ExecutionTrace>,
213}
214
215impl Clone for ExecutionContext {
216 fn clone(&self) -> Self {
217 Self {
219 config: self.config.clone(),
220 state: RwLock::new(HashMap::new()),
221 trace: RwLock::new(ExecutionTrace::new()),
222 }
223 }
224}
225
226impl ExecutionContext {
227 pub fn new(config: ExecutionConfig) -> Self {
229 Self {
230 config,
231 state: RwLock::new(HashMap::new()),
232 trace: RwLock::new(ExecutionTrace::new()),
233 }
234 }
235
236 pub fn config(&self) -> &ExecutionConfig {
238 &self.config
239 }
240
241 pub async fn get_state(&self, key: &str) -> Option<serde_json::Value> {
243 let state = self.state.read().await;
244 state.get(key).cloned()
245 }
246
247 pub async fn set_state(&self, key: impl Into<String>, value: serde_json::Value) {
249 let mut state = self.state.write().await;
250 state.insert(key.into(), value);
251 }
252
253 pub async fn remove_state(&self, key: &str) -> Option<serde_json::Value> {
255 let mut state = self.state.write().await;
256 state.remove(key)
257 }
258
259 pub async fn clear_state(&self) {
261 let mut state = self.state.write().await;
262 state.clear();
263 }
264
265 pub async fn get_trace(&self) -> ExecutionTrace {
267 self.trace.read().await.clone()
268 }
269
270 pub async fn add_execution(&self, execution: AgentExecution) {
272 let mut trace = self.trace.write().await;
273 trace.add_execution(execution);
274 }
275
276 pub async fn complete_trace(&self) {
278 let mut trace = self.trace.write().await;
279 trace.complete();
280 }
281
282 pub fn is_logging_enabled(&self) -> bool {
284 self.config.enable_logging
285 }
286
287 pub fn is_tracing_enabled(&self) -> bool {
289 self.config.enable_tracing
290 }
291}
292
#[cfg(test)]
mod tests {
    use super::*;

    /// Each builder setter must overwrite the corresponding default.
    #[test]
    fn test_execution_config() {
        let base = ExecutionConfig::new();
        let config = base
            .with_timeout(Duration::from_secs(60))
            .with_max_retries(5)
            .with_parallel_limit(20)
            .with_logging(false)
            .with_tracing(false);

        assert_eq!(config.timeout.as_secs(), 60);
        assert_eq!(config.max_retries, 5);
        assert_eq!(config.parallel_limit, 20);
        assert!(!config.enable_logging);
        assert!(!config.enable_tracing);
    }

    /// State can be written, read back, removed, and cleared.
    #[tokio::test]
    async fn test_execution_context() {
        let ctx = ExecutionContext::new(ExecutionConfig::new());

        // Insert a string value and read it back.
        ctx.set_state("key1", serde_json::json!("value1")).await;
        let first = ctx.get_state("key1").await;
        assert_eq!(first, Some(serde_json::json!("value1")));

        // Insert a numeric value under a second key.
        ctx.set_state("key2", serde_json::json!(42)).await;
        let second = ctx.get_state("key2").await;
        assert_eq!(second, Some(serde_json::json!(42)));

        // Removal hands back the stored value and leaves the key absent.
        let removed = ctx.remove_state("key1").await;
        assert_eq!(removed, Some(serde_json::json!("value1")));
        assert!(ctx.get_state("key1").await.is_none());

        // Clearing drops whatever is left.
        ctx.clear_state().await;
        assert!(ctx.get_state("key2").await.is_none());
    }

    /// A trace has no end time or duration until it is completed.
    #[test]
    fn test_execution_trace() {
        let mut trace = ExecutionTrace::new();
        assert!(trace.end_time.is_none());
        assert!(trace.duration_ms.is_none());

        trace.complete();

        assert!(trace.end_time.is_some());
        assert!(trace.duration_ms.is_some());
    }

    /// A new execution is unfinished; `succeed` fills in output and timing.
    #[test]
    fn test_agent_execution() {
        let mut exec = AgentExecution::new(
            "TestAgent",
            crate::orchestration::agent::AgentInput::new("test"),
        );

        assert!(!exec.success);
        assert!(exec.output.is_none());
        assert!(exec.end_time.is_none());

        exec.succeed(AgentOutput::new("result").with_confidence(0.9));

        assert!(exec.success);
        assert!(exec.output.is_some());
        assert!(exec.end_time.is_some());
        assert!(exec.duration_ms.is_some());
    }
}