1use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct Protocol {
11 pub id: String,
13
14 pub name: String,
16
17 pub version: String,
19
20 pub description: String,
22
23 pub strategy: ReasoningStrategy,
25
26 pub input: InputSpec,
28
29 pub steps: Vec<ProtocolStep>,
31
32 pub output: OutputSpec,
34
35 #[serde(default)]
37 pub validation: Vec<ValidationRule>,
38
39 #[serde(default)]
41 pub metadata: ProtocolMetadata,
42}
43
44#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
46#[serde(rename_all = "snake_case")]
47#[derive(Default)]
48pub enum ReasoningStrategy {
49 Expansive,
51 Deductive,
53 #[default]
55 Analytical,
56 Adversarial,
58 Verification,
60 Decision,
62 Empirical,
64}
65
66#[derive(Debug, Clone, Default, Serialize, Deserialize)]
68pub struct InputSpec {
69 #[serde(default)]
71 pub required: Vec<String>,
72
73 #[serde(default)]
75 pub optional: Vec<String>,
76}
77
78#[derive(Debug, Clone, Default, Serialize, Deserialize)]
80pub struct OutputSpec {
81 pub format: String,
83
84 #[serde(default)]
86 pub fields: Vec<String>,
87}
88
89#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct ProtocolStep {
92 pub id: String,
94
95 pub action: StepAction,
97
98 pub prompt_template: String,
100
101 pub output_format: StepOutputFormat,
103
104 #[serde(default = "default_min_confidence")]
106 pub min_confidence: f64,
107
108 #[serde(default)]
110 pub depends_on: Vec<String>,
111
112 #[serde(default)]
114 pub branch: Option<BranchCondition>,
115}
116
/// Serde fallback for `ProtocolStep::min_confidence` when the field is missing.
const fn default_min_confidence() -> f64 {
    0.7
}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
123#[serde(tag = "type", rename_all = "snake_case")]
124pub enum StepAction {
125 Generate {
127 #[serde(default = "default_min_count")]
129 min_count: usize,
130 #[serde(default = "default_max_count")]
132 max_count: usize,
133 },
134
135 Analyze {
137 #[serde(default)]
139 criteria: Vec<String>,
140 },
141
142 Synthesize {
144 #[serde(default)]
146 aggregation: AggregationType,
147 },
148
149 Validate {
151 #[serde(default)]
153 rules: Vec<String>,
154 },
155
156 Critique {
158 #[serde(default)]
160 severity: CritiqueSeverity,
161 },
162
163 Decide {
165 #[serde(default)]
167 method: DecisionMethod,
168 },
169
170 CrossReference {
172 #[serde(default = "default_min_sources")]
174 min_sources: usize,
175 },
176}
177
/// Serde fallback for `StepAction::Generate::min_count` when the field is missing.
const fn default_min_count() -> usize {
    3
}
181
/// Serde fallback for `StepAction::Generate::max_count` when the field is missing.
const fn default_max_count() -> usize {
    10
}
185
/// Serde fallback for `StepAction::CrossReference::min_sources` when the field is missing.
const fn default_min_sources() -> usize {
    3
}
189
190#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
192#[serde(rename_all = "snake_case")]
193pub enum StepOutputFormat {
194 #[default]
196 Text,
197 List,
199 Structured,
201 Score,
203 Boolean,
205}
206
207#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
209#[serde(rename_all = "snake_case")]
210pub enum AggregationType {
211 #[default]
213 ThematicClustering,
214 Concatenate,
216 WeightedMerge,
218 Consensus,
220}
221
222#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
224#[serde(rename_all = "snake_case")]
225pub enum CritiqueSeverity {
226 Light,
228 #[default]
230 Standard,
231 Adversarial,
233 Brutal,
235}
236
237#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
239#[serde(rename_all = "snake_case")]
240pub enum DecisionMethod {
241 #[default]
243 ProsCons,
244 MultiCriteria,
246 ExpectedValue,
248 RegretMinimization,
250}
251
252#[derive(Debug, Clone, Serialize, Deserialize)]
254#[serde(tag = "type", rename_all = "snake_case")]
255pub enum BranchCondition {
256 ConfidenceBelow {
258 threshold: f64,
260 },
261 ConfidenceAbove {
263 threshold: f64,
265 },
266 OutputEquals {
268 field: String,
270 value: String,
272 },
273 Always,
275}
276
277#[derive(Debug, Clone, Serialize, Deserialize)]
279#[serde(tag = "rule", rename_all = "snake_case")]
280pub enum ValidationRule {
281 MinCount {
283 field: String,
285 value: usize,
287 },
288 MaxCount {
290 field: String,
292 value: usize,
294 },
295 ConfidenceRange {
297 min: f64,
299 max: f64,
301 },
302 Required {
304 field: String,
306 },
307 Custom {
309 expression: String,
311 },
312}
313
314#[derive(Debug, Clone, Default, Serialize, Deserialize)]
316pub struct ProtocolMetadata {
317 #[serde(default)]
319 pub category: String,
320
321 #[serde(default)]
323 pub composable_with: Vec<String>,
324
325 #[serde(default)]
327 pub typical_tokens: u32,
328
329 #[serde(default)]
331 pub estimated_latency_ms: u32,
332
333 #[serde(default)]
335 pub extra: HashMap<String, serde_json::Value>,
336}
337
338impl Protocol {
339 pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
341 Self {
342 id: id.into(),
343 name: name.into(),
344 version: "1.0.0".to_string(),
345 description: String::new(),
346 strategy: ReasoningStrategy::default(),
347 input: InputSpec::default(),
348 steps: Vec::new(),
349 output: OutputSpec::default(),
350 validation: Vec::new(),
351 metadata: ProtocolMetadata::default(),
352 }
353 }
354
355 pub fn with_step(mut self, step: ProtocolStep) -> Self {
357 self.steps.push(step);
358 self
359 }
360
361 pub fn with_strategy(mut self, strategy: ReasoningStrategy) -> Self {
363 self.strategy = strategy;
364 self
365 }
366
367 pub fn validate(&self) -> Result<(), Vec<String>> {
369 let mut errors = Vec::new();
370
371 if self.id.is_empty() {
372 errors.push("Protocol ID cannot be empty".to_string());
373 }
374
375 if self.steps.is_empty() {
376 errors.push("Protocol must have at least one step".to_string());
377 }
378
379 let step_ids: Vec<&str> = self.steps.iter().map(|s| s.id.as_str()).collect();
381 for step in &self.steps {
382 for dep in &step.depends_on {
383 if !step_ids.contains(&dep.as_str()) {
384 errors.push(format!(
385 "Step '{}' depends on unknown step '{}'",
386 step.id, dep
387 ));
388 }
389 }
390 }
391
392 if errors.is_empty() {
393 Ok(())
394 } else {
395 Err(errors)
396 }
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Single `Generate` step used by the construction test.
    fn sample_generate_step() -> ProtocolStep {
        ProtocolStep {
            id: "step1".to_string(),
            action: StepAction::Generate {
                min_count: 5,
                max_count: 10,
            },
            prompt_template: "Generate ideas for: {{query}}".to_string(),
            output_format: StepOutputFormat::List,
            min_confidence: 0.7,
            depends_on: Vec::new(),
            branch: None,
        }
    }

    #[test]
    fn test_protocol_creation() {
        let built = Protocol::new("test", "Test Protocol")
            .with_strategy(ReasoningStrategy::Expansive)
            .with_step(sample_generate_step());

        assert_eq!(built.id, "test");
        assert_eq!(built.steps.len(), 1);
        assert!(built.validate().is_ok());
    }

    #[test]
    fn test_protocol_validation_empty_steps() {
        let outcome = Protocol::new("test", "Test Protocol").validate();
        assert!(outcome.is_err());

        let messages = outcome.unwrap_err();
        let mentions_missing_steps = messages
            .iter()
            .any(|message| message.contains("at least one step"));
        assert!(mentions_missing_steps);
    }

    #[test]
    fn test_step_action_serialization() {
        let original = StepAction::Generate {
            min_count: 5,
            max_count: 10,
        };

        let encoded = serde_json::to_string(&original).expect("Failed to serialize");
        assert!(encoded.contains("generate"));

        let decoded: StepAction = serde_json::from_str(&encoded).expect("Failed to deserialize");
        if let StepAction::Generate {
            min_count,
            max_count,
        } = decoded
        {
            assert_eq!(min_count, 5);
            assert_eq!(max_count, 10);
        } else {
            panic!("Wrong action type");
        }
    }
}