1use serde::{Deserialize, Serialize};
7use std::time::Duration;
8
/// Event emitted during example generation, serializable for streaming to clients.
///
/// Uses serde's internally tagged representation: each variant serializes as a JSON
/// object whose `"type"` field is the snake_case variant name (for example
/// `{"type":"started","tool_name":"kubernetes:apply",...}`), followed by the
/// variant's fields. Variant names, field names and field order are therefore part
/// of the wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum GenerationEvent {
    /// Work on one tool begins; built by `GenerationEvent::started`.
    Started {
        /// Tool identifier, e.g. `"kubernetes:apply"`.
        tool_name: String,
        /// Number of tools in the run.
        total_tools: usize,
        /// Position of this tool within the run.
        /// NOTE(review): whether this is 0- or 1-based is not fixed here — confirm with callers.
        current_index: usize,
    },

    /// Free-form reasoning text; built by `GenerationEvent::thinking`.
    Thinking {
        /// The reasoning text.
        thought: String,
    },

    /// A search was issued with the given query text.
    Searching {
        /// Query text of the search.
        query: String,
    },

    /// Outcome of a search.
    SearchResult {
        /// Names of the matching tools.
        tools: Vec<String>,
        /// Match count. NOTE(review): nothing enforces `count == tools.len()`.
        count: usize,
    },

    /// A generated example.
    Example {
        /// The example payload.
        example: GeneratedExample,
    },

    /// Validation outcome for one generated example.
    Validation {
        /// Whether the example passed validation.
        valid: bool,
        /// Validation error messages; presumably empty when `valid` is true — confirm.
        errors: Vec<String>,
        /// Index of the validated example.
        /// NOTE(review): presumably within the current tool's examples — confirm.
        example_index: usize,
    },

    /// Progress report; built by `GenerationEvent::progress`.
    Progress {
        /// Units completed so far.
        current: usize,
        /// Total units expected.
        total: usize,
        /// When built via `GenerationEvent::progress`: `current / total * 100`,
        /// or `0.0` when `total` is 0. Not clamped to 100.
        percent: f32,
        /// Optional status text, e.g. `"Processing tools"`.
        message: Option<String>,
    },

    /// Summary for one finished tool.
    ToolCompleted {
        /// Tool the summary refers to.
        tool_name: String,
        /// Number of examples produced for the tool.
        examples_generated: usize,
        /// Number of those examples that were valid.
        valid_examples: usize,
        /// Elapsed time in milliseconds.
        duration_ms: u64,
    },

    /// Summary for the whole run; built by `GenerationEvent::completed`.
    Completed {
        /// Examples produced across all tools.
        total_examples: usize,
        /// Valid examples across all tools.
        total_valid: usize,
        /// Number of tools processed.
        total_tools: usize,
        /// Elapsed time in milliseconds.
        duration_ms: u64,
    },

    /// An error report; built by `GenerationEvent::error` or `GenerationEvent::tool_error`.
    Error {
        /// Error description.
        message: String,
        /// Whether the producer flagged the error as recoverable.
        recoverable: bool,
        /// Tool the error relates to; `None` for errors not tied to a tool.
        /// Serialized as `null` when `None` (no `skip_serializing_if`).
        tool_name: Option<String>,
    },

    /// Wraps one [`AgentStep`] record.
    AgentStep {
        /// The step payload.
        step: AgentStep,
    },
}
118
119impl GenerationEvent {
120 pub fn started(tool_name: impl Into<String>, total_tools: usize, current_index: usize) -> Self {
122 Self::Started {
123 tool_name: tool_name.into(),
124 total_tools,
125 current_index,
126 }
127 }
128
129 pub fn thinking(thought: impl Into<String>) -> Self {
131 Self::Thinking {
132 thought: thought.into(),
133 }
134 }
135
136 pub fn progress(current: usize, total: usize, message: Option<String>) -> Self {
138 let percent = if total > 0 {
139 (current as f32 / total as f32) * 100.0
140 } else {
141 0.0
142 };
143 Self::Progress {
144 current,
145 total,
146 percent,
147 message,
148 }
149 }
150
151 pub fn error(message: impl Into<String>, recoverable: bool) -> Self {
153 Self::Error {
154 message: message.into(),
155 recoverable,
156 tool_name: None,
157 }
158 }
159
160 pub fn tool_error(message: impl Into<String>, tool_name: impl Into<String>, recoverable: bool) -> Self {
162 Self::Error {
163 message: message.into(),
164 recoverable,
165 tool_name: Some(tool_name.into()),
166 }
167 }
168
169 pub fn completed(total_examples: usize, total_valid: usize, total_tools: usize, duration: Duration) -> Self {
171 Self::Completed {
172 total_examples,
173 total_valid,
174 total_tools,
175 duration_ms: duration.as_millis() as u64,
176 }
177 }
178
179 pub fn to_sse_data(&self) -> String {
181 format!("data: {}\n\n", serde_json::to_string(self).unwrap_or_default())
182 }
183
184 pub fn to_sse(&self) -> String {
186 let event_type = match self {
187 Self::Started { .. } => "started",
188 Self::Thinking { .. } => "thinking",
189 Self::Searching { .. } => "searching",
190 Self::SearchResult { .. } => "search_result",
191 Self::Example { .. } => "example",
192 Self::Validation { .. } => "validation",
193 Self::Progress { .. } => "progress",
194 Self::ToolCompleted { .. } => "tool_completed",
195 Self::Completed { .. } => "completed",
196 Self::Error { .. } => "error",
197 Self::AgentStep { .. } => "agent_step",
198 };
199 format!(
200 "event: {}\ndata: {}\n\n",
201 event_type,
202 serde_json::to_string(self).unwrap_or_default()
203 )
204 }
205}
206
/// A single generated usage example for a tool.
///
/// Fields serialize in declaration order; optional fields are omitted from the
/// JSON output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneratedExample {
    /// The example command line, e.g. `skill run k8s:apply --file=deploy.yaml`.
    pub command: String,

    /// Explanation accompanying the command.
    pub explanation: String,

    /// Confidence score; defaults to `0.0` when absent from input.
    /// The `with_confidence` builder clamps into `[0.0, 1.0]`; values arriving
    /// through deserialization are not clamped.
    #[serde(default)]
    pub confidence: f32,

    /// Whether the example passed validation; defaults to `false` when absent from input.
    #[serde(default)]
    pub validated: bool,

    /// Optional category label; omitted from JSON when `None` and read as `None` when missing.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub category: Option<String>,

    /// Optional structured parameters (arbitrary JSON); omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<serde_json::Value>,
}
236
237impl GeneratedExample {
238 pub fn new(command: impl Into<String>, explanation: impl Into<String>) -> Self {
240 Self {
241 command: command.into(),
242 explanation: explanation.into(),
243 confidence: 0.0,
244 validated: false,
245 category: None,
246 parameters: None,
247 }
248 }
249
250 pub fn with_confidence(mut self, confidence: f32) -> Self {
252 self.confidence = confidence.clamp(0.0, 1.0);
253 self
254 }
255
256 pub fn with_validated(mut self, validated: bool) -> Self {
258 self.validated = validated;
259 self
260 }
261
262 pub fn with_category(mut self, category: impl Into<String>) -> Self {
264 self.category = Some(category.into());
265 self
266 }
267}
268
/// One step of an agent-style reasoning trace, carried by `GenerationEvent::AgentStep`.
///
/// NOTE(review): `is_final` and `final_answer` are independent fields, so the type
/// permits combinations (e.g. `is_final == false` with a `final_answer`) that the
/// constructors in this file never produce.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentStep {
    /// Step number, as supplied by the caller.
    pub step_number: usize,

    /// Reasoning text for this step.
    pub thought: String,

    /// Follow-up question posed at this step; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub follow_up_question: Option<String>,

    /// Search results attached to this step; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub search_results: Option<Vec<SearchResultRef>>,

    /// Whether this is the final step. It has no `#[serde(default)]`, so it is
    /// required when deserializing.
    pub is_final: bool,

    /// Final answer text; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub final_answer: Option<String>,
}
297
298impl AgentStep {
299 pub fn thinking(step_number: usize, thought: impl Into<String>) -> Self {
301 Self {
302 step_number,
303 thought: thought.into(),
304 follow_up_question: None,
305 search_results: None,
306 is_final: false,
307 final_answer: None,
308 }
309 }
310
311 pub fn follow_up(step_number: usize, thought: impl Into<String>, question: impl Into<String>) -> Self {
313 Self {
314 step_number,
315 thought: thought.into(),
316 follow_up_question: Some(question.into()),
317 search_results: None,
318 is_final: false,
319 final_answer: None,
320 }
321 }
322
323 pub fn final_answer(step_number: usize, thought: impl Into<String>, answer: impl Into<String>) -> Self {
325 Self {
326 step_number,
327 thought: thought.into(),
328 follow_up_question: None,
329 search_results: None,
330 is_final: true,
331 final_answer: Some(answer.into()),
332 }
333 }
334
335 pub fn with_search_results(mut self, results: Vec<SearchResultRef>) -> Self {
337 self.search_results = Some(results);
338 self
339 }
340}
341
/// Reference to a tool returned by a search, as attached to an [`AgentStep`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResultRef {
    /// Tool identifier.
    pub tool_name: String,
    /// Tool description text.
    pub description: String,
    /// Search score. NOTE(review): scale and direction (is higher better?) are not
    /// defined here — confirm with the search implementation.
    pub score: f32,
}
352
/// Produces [`GenerationEvent`]s pre-filled with one tool's context (name and
/// position in the run), so per-event calls don't repeat it.
///
/// Derives `Debug` and `Clone` like the other public types in this module, so it
/// can be logged and shared across tasks.
#[derive(Debug, Clone)]
pub struct GenerationStreamBuilder {
    /// Tool the events refer to.
    tool_name: String,
    /// Number of tools in the run.
    total_tools: usize,
    /// Position of this tool within the run.
    current_index: usize,
}
363
364impl GenerationStreamBuilder {
365 pub fn new(tool_name: impl Into<String>, total_tools: usize, current_index: usize) -> Self {
367 Self {
368 tool_name: tool_name.into(),
369 total_tools,
370 current_index,
371 }
372 }
373
374 pub fn started(&self) -> GenerationEvent {
376 GenerationEvent::started(&self.tool_name, self.total_tools, self.current_index)
377 }
378
379 pub fn thinking(&self, thought: impl Into<String>) -> GenerationEvent {
381 GenerationEvent::thinking(thought)
382 }
383
384 pub fn example(&self, example: GeneratedExample) -> GenerationEvent {
386 GenerationEvent::Example { example }
387 }
388
389 pub fn validation(&self, valid: bool, errors: Vec<String>, example_index: usize) -> GenerationEvent {
391 GenerationEvent::Validation {
392 valid,
393 errors,
394 example_index,
395 }
396 }
397
398 pub fn tool_completed(&self, examples_generated: usize, valid_examples: usize, duration: Duration) -> GenerationEvent {
400 GenerationEvent::ToolCompleted {
401 tool_name: self.tool_name.clone(),
402 examples_generated,
403 valid_examples,
404 duration_ms: duration.as_millis() as u64,
405 }
406 }
407
408 pub fn error(&self, message: impl Into<String>, recoverable: bool) -> GenerationEvent {
410 GenerationEvent::tool_error(message, &self.tool_name, recoverable)
411 }
412}
413
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generation_event_serialization() {
        let event = GenerationEvent::started("kubernetes:apply", 10, 1);
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("\"type\":\"started\""));
        assert!(json.contains("\"tool_name\":\"kubernetes:apply\""));
        assert!(json.contains("\"total_tools\":10"));

        let parsed: GenerationEvent = serde_json::from_str(&json).unwrap();
        if let GenerationEvent::Started { tool_name, total_tools, current_index } = parsed {
            assert_eq!(tool_name, "kubernetes:apply");
            assert_eq!(total_tools, 10);
            assert_eq!(current_index, 1);
        } else {
            panic!("Expected Started event");
        }
    }

    #[test]
    fn test_thinking_event() {
        let event = GenerationEvent::thinking("Analyzing parameter schema...");
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("\"type\":\"thinking\""));
        assert!(json.contains("Analyzing parameter schema"));
    }

    #[test]
    fn test_progress_event() {
        let event = GenerationEvent::progress(5, 10, Some("Processing tools".to_string()));
        if let GenerationEvent::Progress { current, total, percent, message } = event {
            assert_eq!(current, 5);
            assert_eq!(total, 10);
            assert!((percent - 50.0).abs() < 0.01);
            assert_eq!(message, Some("Processing tools".to_string()));
        } else {
            panic!("Expected Progress event");
        }
    }

    #[test]
    fn test_progress_zero_total_is_zero_percent() {
        // A zero total must not produce NaN/infinity.
        if let GenerationEvent::Progress { percent, .. } = GenerationEvent::progress(3, 0, None) {
            assert!(percent.abs() < 0.01);
        } else {
            panic!("Expected Progress event");
        }
    }

    #[test]
    fn test_error_event() {
        let event = GenerationEvent::tool_error("Connection timeout", "k8s:apply", true);
        let json = serde_json::to_string(&event).unwrap();
        assert!(json.contains("\"type\":\"error\""));
        assert!(json.contains("\"recoverable\":true"));
        assert!(json.contains("\"tool_name\":\"k8s:apply\""));

        // Tool-less errors serialize the missing tool as null (no skip_serializing_if).
        let generic = serde_json::to_string(&GenerationEvent::error("boom", false)).unwrap();
        assert!(generic.contains("\"tool_name\":null"));
    }

    #[test]
    fn test_generated_example() {
        let example = GeneratedExample::new(
            "skill run k8s:apply --file=deploy.yaml",
            "Apply a Kubernetes deployment manifest",
        )
        .with_confidence(0.95)
        .with_validated(true)
        .with_category("deployment");

        assert_eq!(example.command, "skill run k8s:apply --file=deploy.yaml");
        assert!((example.confidence - 0.95).abs() < 0.01);
        assert!(example.validated);
        assert_eq!(example.category, Some("deployment".to_string()));
    }

    #[test]
    fn test_agent_step() {
        let step = AgentStep::follow_up(
            1,
            "I need to find tools for container deployment",
            "What tools handle Kubernetes deployments?",
        );

        assert_eq!(step.step_number, 1);
        assert!(!step.is_final);
        assert!(step.follow_up_question.is_some());
        assert!(step.final_answer.is_none());

        let final_step = AgentStep::final_answer(
            3,
            "Based on my search, I recommend using kubernetes:apply",
            "Use kubernetes:apply with --file flag to deploy your manifest",
        );

        assert!(final_step.is_final);
        assert!(final_step.final_answer.is_some());
    }

    #[test]
    fn test_sse_format() {
        let event = GenerationEvent::thinking("Processing...");
        let sse = event.to_sse();

        assert!(sse.starts_with("event: thinking\n"));
        assert!(sse.contains("data: "));
        assert!(sse.ends_with("\n\n"));

        let data_only = event.to_sse_data();
        assert!(data_only.starts_with("data: "));
        assert!(data_only.ends_with("\n\n"));
        assert!(!data_only.contains("event: "));
    }

    #[test]
    fn test_sse_event_name_matches_serde_tag() {
        // Guards against the hand-written SSE names drifting from serde's `type` tag.
        let events = vec![
            GenerationEvent::started("t", 1, 0),
            GenerationEvent::thinking("x"),
            GenerationEvent::Searching { query: "q".to_string() },
            GenerationEvent::SearchResult { tools: vec!["a".to_string()], count: 1 },
            GenerationEvent::Example { example: GeneratedExample::new("c", "e") },
            GenerationEvent::Validation { valid: true, errors: vec![], example_index: 0 },
            GenerationEvent::progress(1, 2, None),
            GenerationEvent::ToolCompleted {
                tool_name: "t".to_string(),
                examples_generated: 1,
                valid_examples: 1,
                duration_ms: 5,
            },
            GenerationEvent::completed(1, 1, 1, Duration::from_millis(5)),
            GenerationEvent::error("boom", false),
            GenerationEvent::AgentStep { step: AgentStep::thinking(1, "hmm") },
        ];
        for event in &events {
            let sse = event.to_sse();
            let mut lines = sse.lines();
            let name = lines
                .next()
                .and_then(|l| l.strip_prefix("event: "))
                .expect("SSE frame starts with an event line");
            let data = lines
                .next()
                .and_then(|l| l.strip_prefix("data: "))
                .expect("SSE frame has a data line");
            let json: serde_json::Value = serde_json::from_str(data).expect("data line is valid JSON");
            assert_eq!(json["type"], name);
        }
    }

    #[test]
    fn test_stream_builder() {
        let builder = GenerationStreamBuilder::new("docker:build", 5, 2);

        if let GenerationEvent::Started { tool_name, total_tools, current_index } = builder.started() {
            assert_eq!(tool_name, "docker:build");
            assert_eq!(total_tools, 5);
            assert_eq!(current_index, 2);
        } else {
            panic!("Expected Started event");
        }

        let example = GeneratedExample::new("skill run docker:build .", "Build Docker image");
        let event = builder.example(example);
        assert!(matches!(event, GenerationEvent::Example { .. }));

        if let GenerationEvent::ToolCompleted { tool_name, examples_generated, valid_examples, duration_ms } =
            builder.tool_completed(3, 2, Duration::from_millis(1500))
        {
            assert_eq!(tool_name, "docker:build");
            assert_eq!(examples_generated, 3);
            assert_eq!(valid_examples, 2);
            assert_eq!(duration_ms, 1500);
        } else {
            panic!("Expected ToolCompleted event");
        }

        if let GenerationEvent::Error { tool_name, recoverable, .. } = builder.error("failed", false) {
            assert_eq!(tool_name.as_deref(), Some("docker:build"));
            assert!(!recoverable);
        } else {
            panic!("Expected Error event");
        }
    }

    #[test]
    fn test_completed_event() {
        let event = GenerationEvent::completed(50, 45, 10, Duration::from_secs(30));
        if let GenerationEvent::Completed { total_examples, total_valid, total_tools, duration_ms } = event {
            assert_eq!(total_examples, 50);
            assert_eq!(total_valid, 45);
            assert_eq!(total_tools, 10);
            assert_eq!(duration_ms, 30000);
        } else {
            panic!("Expected Completed event");
        }
    }
}