reasonkit/thinktool/
protocol.rs

//! Protocol definition types
//!
//! Defines the schema for ThinkTool protocols.
4
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7
8/// A ThinkTool Protocol definition
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct Protocol {
11    /// Unique protocol identifier (e.g., "gigathink", "laserlogic")
12    pub id: String,
13
14    /// Human-readable name
15    pub name: String,
16
17    /// Protocol version (semver)
18    pub version: String,
19
20    /// Brief description
21    pub description: String,
22
23    /// Reasoning strategy category
24    pub strategy: ReasoningStrategy,
25
26    /// Input specification
27    pub input: InputSpec,
28
29    /// Protocol steps (ordered)
30    pub steps: Vec<ProtocolStep>,
31
32    /// Output specification
33    pub output: OutputSpec,
34
35    /// Validation rules
36    #[serde(default)]
37    pub validation: Vec<ValidationRule>,
38
39    /// Metadata for composition
40    #[serde(default)]
41    pub metadata: ProtocolMetadata,
42}
43
44/// Reasoning strategy categories
45#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
46#[serde(rename_all = "snake_case")]
47#[derive(Default)]
48pub enum ReasoningStrategy {
49    /// Divergent thinking - maximize perspectives
50    Expansive,
51    /// Convergent thinking - deduce conclusions
52    Deductive,
53    /// Break down to fundamentals
54    #[default]
55    Analytical,
56    /// Challenge and critique
57    Adversarial,
58    /// Cross-reference and confirm
59    Verification,
60    /// Weigh options systematically
61    Decision,
62    /// Scientific method
63    Empirical,
64}
65
66/// Input specification for a protocol
67#[derive(Debug, Clone, Default, Serialize, Deserialize)]
68pub struct InputSpec {
69    /// Required input fields
70    #[serde(default)]
71    pub required: Vec<String>,
72
73    /// Optional input fields
74    #[serde(default)]
75    pub optional: Vec<String>,
76}
77
78/// Output specification for a protocol
79#[derive(Debug, Clone, Default, Serialize, Deserialize)]
80pub struct OutputSpec {
81    /// Output format name
82    pub format: String,
83
84    /// Output fields
85    #[serde(default)]
86    pub fields: Vec<String>,
87}
88
89/// A single step in a protocol
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct ProtocolStep {
92    /// Step identifier within protocol
93    pub id: String,
94
95    /// What this step does
96    pub action: StepAction,
97
98    /// Prompt template (with {{placeholders}})
99    pub prompt_template: String,
100
101    /// Expected output format
102    pub output_format: StepOutputFormat,
103
104    /// Minimum confidence to proceed (0.0 - 1.0)
105    #[serde(default = "default_min_confidence")]
106    pub min_confidence: f64,
107
108    /// Dependencies on previous steps
109    #[serde(default)]
110    pub depends_on: Vec<String>,
111
112    /// Optional branching conditions
113    #[serde(default)]
114    pub branch: Option<BranchCondition>,
115}
116
/// Serde fallback for a step's `min_confidence` when the field is omitted.
fn default_min_confidence() -> f64 {
    const FALLBACK_MIN_CONFIDENCE: f64 = 0.7;
    FALLBACK_MIN_CONFIDENCE
}
120
121/// Step action types
122#[derive(Debug, Clone, Serialize, Deserialize)]
123#[serde(tag = "type", rename_all = "snake_case")]
124pub enum StepAction {
125    /// Generate perspectives/ideas
126    Generate {
127        /// Minimum number of items to generate
128        #[serde(default = "default_min_count")]
129        min_count: usize,
130        /// Maximum number of items to generate
131        #[serde(default = "default_max_count")]
132        max_count: usize,
133    },
134
135    /// Analyze/evaluate input
136    Analyze {
137        /// Criteria for analysis
138        #[serde(default)]
139        criteria: Vec<String>,
140    },
141
142    /// Synthesize multiple inputs
143    Synthesize {
144        /// Aggregation method to use
145        #[serde(default)]
146        aggregation: AggregationType,
147    },
148
149    /// Validate against rules
150    Validate {
151        /// Validation rules to apply
152        #[serde(default)]
153        rules: Vec<String>,
154    },
155
156    /// Challenge/critique
157    Critique {
158        /// Severity level for critique
159        #[serde(default)]
160        severity: CritiqueSeverity,
161    },
162
163    /// Make decision
164    Decide {
165        /// Decision method to use
166        #[serde(default)]
167        method: DecisionMethod,
168    },
169
170    /// Cross-reference sources
171    CrossReference {
172        /// Minimum number of sources required
173        #[serde(default = "default_min_sources")]
174        min_sources: usize,
175    },
176}
177
/// Serde fallback for `StepAction::Generate::min_count` when the field is omitted.
fn default_min_count() -> usize {
    const FALLBACK_MIN_COUNT: usize = 3;
    FALLBACK_MIN_COUNT
}
181
/// Serde fallback for `StepAction::Generate::max_count` when the field is omitted.
fn default_max_count() -> usize {
    const FALLBACK_MAX_COUNT: usize = 10;
    FALLBACK_MAX_COUNT
}
185
/// Serde fallback for `StepAction::CrossReference::min_sources` when the field is omitted.
fn default_min_sources() -> usize {
    const FALLBACK_MIN_SOURCES: usize = 3;
    FALLBACK_MIN_SOURCES
}
189
190/// Output format for a step
191#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
192#[serde(rename_all = "snake_case")]
193pub enum StepOutputFormat {
194    /// Free-form text
195    #[default]
196    Text,
197    /// Numbered/bulleted list
198    List,
199    /// Key-value structured data
200    Structured,
201    /// Numeric score (0.0 - 1.0)
202    Score,
203    /// Boolean decision
204    Boolean,
205}
206
207/// Aggregation types for synthesis
208#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
209#[serde(rename_all = "snake_case")]
210pub enum AggregationType {
211    /// Group by themes
212    #[default]
213    ThematicClustering,
214    /// Simple concatenation
215    Concatenate,
216    /// Weighted by confidence
217    WeightedMerge,
218    /// Majority voting
219    Consensus,
220}
221
222/// Severity levels for critique
223#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
224#[serde(rename_all = "snake_case")]
225pub enum CritiqueSeverity {
226    /// Light review
227    Light,
228    /// Standard critique
229    #[default]
230    Standard,
231    /// Adversarial challenge
232    Adversarial,
233    /// Maximum scrutiny
234    Brutal,
235}
236
237/// Methods for decision making
238#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
239#[serde(rename_all = "snake_case")]
240pub enum DecisionMethod {
241    /// Simple pros/cons
242    #[default]
243    ProsCons,
244    /// Multi-criteria analysis
245    MultiCriteria,
246    /// Expected value calculation
247    ExpectedValue,
248    /// Regret minimization
249    RegretMinimization,
250}
251
252/// Conditional branching
253#[derive(Debug, Clone, Serialize, Deserialize)]
254#[serde(tag = "type", rename_all = "snake_case")]
255pub enum BranchCondition {
256    /// Branch if confidence below threshold
257    ConfidenceBelow {
258        /// Confidence threshold value
259        threshold: f64,
260    },
261    /// Branch if confidence above threshold
262    ConfidenceAbove {
263        /// Confidence threshold value
264        threshold: f64,
265    },
266    /// Branch based on output value
267    OutputEquals {
268        /// Field name to check
269        field: String,
270        /// Expected value
271        value: String,
272    },
273    /// Always execute (unconditional)
274    Always,
275}
276
277/// Validation rule for protocol output
278#[derive(Debug, Clone, Serialize, Deserialize)]
279#[serde(tag = "rule", rename_all = "snake_case")]
280pub enum ValidationRule {
281    /// Minimum number of items
282    MinCount {
283        /// Field name to validate
284        field: String,
285        /// Minimum count value
286        value: usize,
287    },
288    /// Maximum number of items
289    MaxCount {
290        /// Field name to validate
291        field: String,
292        /// Maximum count value
293        value: usize,
294    },
295    /// Confidence must be in range
296    ConfidenceRange {
297        /// Minimum confidence value
298        min: f64,
299        /// Maximum confidence value
300        max: f64,
301    },
302    /// Field must be present
303    Required {
304        /// Required field name
305        field: String,
306    },
307    /// Custom validation (expression)
308    Custom {
309        /// Validation expression
310        expression: String,
311    },
312}
313
314/// Protocol metadata for composition and optimization
315#[derive(Debug, Clone, Default, Serialize, Deserialize)]
316pub struct ProtocolMetadata {
317    /// Category tag
318    #[serde(default)]
319    pub category: String,
320
321    /// Protocols this can be composed with
322    #[serde(default)]
323    pub composable_with: Vec<String>,
324
325    /// Typical token usage
326    #[serde(default)]
327    pub typical_tokens: u32,
328
329    /// Estimated latency in milliseconds
330    #[serde(default)]
331    pub estimated_latency_ms: u32,
332
333    /// Additional key-value metadata
334    #[serde(default)]
335    pub extra: HashMap<String, serde_json::Value>,
336}
337
338impl Protocol {
339    /// Create a new protocol with required fields
340    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
341        Self {
342            id: id.into(),
343            name: name.into(),
344            version: "1.0.0".to_string(),
345            description: String::new(),
346            strategy: ReasoningStrategy::default(),
347            input: InputSpec::default(),
348            steps: Vec::new(),
349            output: OutputSpec::default(),
350            validation: Vec::new(),
351            metadata: ProtocolMetadata::default(),
352        }
353    }
354
355    /// Add a step to the protocol
356    pub fn with_step(mut self, step: ProtocolStep) -> Self {
357        self.steps.push(step);
358        self
359    }
360
361    /// Set the reasoning strategy
362    pub fn with_strategy(mut self, strategy: ReasoningStrategy) -> Self {
363        self.strategy = strategy;
364        self
365    }
366
367    /// Validate protocol definition
368    pub fn validate(&self) -> Result<(), Vec<String>> {
369        let mut errors = Vec::new();
370
371        if self.id.is_empty() {
372            errors.push("Protocol ID cannot be empty".to_string());
373        }
374
375        if self.steps.is_empty() {
376            errors.push("Protocol must have at least one step".to_string());
377        }
378
379        // Check step dependencies
380        let step_ids: Vec<&str> = self.steps.iter().map(|s| s.id.as_str()).collect();
381        for step in &self.steps {
382            for dep in &step.depends_on {
383                if !step_ids.contains(&dep.as_str()) {
384                    errors.push(format!(
385                        "Step '{}' depends on unknown step '{}'",
386                        step.id, dep
387                    ));
388                }
389            }
390        }
391
392        if errors.is_empty() {
393            Ok(())
394        } else {
395            Err(errors)
396        }
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a list-producing `Generate` step with the given ID and dependencies.
    fn generate_step(id: &str, depends_on: &[&str]) -> ProtocolStep {
        ProtocolStep {
            id: id.to_string(),
            action: StepAction::Generate {
                min_count: 5,
                max_count: 10,
            },
            prompt_template: "Generate ideas for: {{query}}".to_string(),
            output_format: StepOutputFormat::List,
            min_confidence: 0.7,
            depends_on: depends_on.iter().map(|&dep| dep.to_owned()).collect(),
            branch: None,
        }
    }

    #[test]
    fn test_protocol_creation() {
        let protocol = Protocol::new("test", "Test Protocol")
            .with_strategy(ReasoningStrategy::Expansive)
            .with_step(generate_step("step1", &[]));

        assert_eq!(protocol.id, "test");
        assert_eq!(protocol.strategy, ReasoningStrategy::Expansive);
        assert_eq!(protocol.steps.len(), 1);
        assert!(protocol.validate().is_ok());
    }

    #[test]
    fn test_protocol_validation_empty_steps() {
        let protocol = Protocol::new("test", "Test Protocol");
        let errors = protocol
            .validate()
            .expect_err("a protocol without steps must be rejected");
        assert!(errors.iter().any(|e| e.contains("at least one step")));
    }

    #[test]
    fn test_protocol_validation_unknown_dependency() {
        let protocol = Protocol::new("test", "Test Protocol")
            .with_step(generate_step("step1", &["missing"]));
        let errors = protocol
            .validate()
            .expect_err("a dependency on a nonexistent step must be rejected");
        assert!(errors
            .iter()
            .any(|e| e.contains("depends on unknown step 'missing'")));
    }

    #[test]
    fn test_step_action_serialization() {
        let action = StepAction::Generate {
            min_count: 5,
            max_count: 10,
        };
        let json = serde_json::to_string(&action).expect("Failed to serialize");
        assert!(json.contains("generate"));

        let parsed: StepAction = serde_json::from_str(&json).expect("Failed to deserialize");
        // Match the exact variant and field values so a wrong round-trip fails the assertion.
        assert!(
            matches!(
                parsed,
                StepAction::Generate {
                    min_count: 5,
                    max_count: 10,
                }
            ),
            "Wrong action type or field values: {:?}",
            parsed
        );
    }
}
458}