ricecoder_learning/
models.rs

//! Core data models for the learning system.
2use chrono::{DateTime, Utc};
3use serde::{Deserialize, Serialize};
4use std::path::PathBuf;
5use uuid::Uuid;
6
7/// Scope where rules are stored and applied
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
9pub enum RuleScope {
10    /// Global scope: ~/.ricecoder/rules/
11    Global,
12    /// Project scope: ./.ricecoder/rules/
13    Project,
14    /// Session-only (in-memory)
15    Session,
16}
17
18impl std::fmt::Display for RuleScope {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        match self {
21            RuleScope::Global => write!(f, "global"),
22            RuleScope::Project => write!(f, "project"),
23            RuleScope::Session => write!(f, "session"),
24        }
25    }
26}
27
28/// Source of a rule
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
30pub enum RuleSource {
31    /// Automatically captured from user decisions
32    Learned,
33    /// User-defined rule
34    Manual,
35    /// Promoted from project to global scope
36    Promoted,
37}
38
39impl std::fmt::Display for RuleSource {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        match self {
42            RuleSource::Learned => write!(f, "learned"),
43            RuleSource::Manual => write!(f, "manual"),
44            RuleSource::Promoted => write!(f, "promoted"),
45        }
46    }
47}
48
49/// A learned rule that guides code generation
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct Rule {
52    /// Unique identifier for the rule
53    pub id: String,
54    /// Scope where this rule is stored
55    pub scope: RuleScope,
56    /// Pattern that triggers this rule
57    pub pattern: String,
58    /// Action to take when pattern matches
59    pub action: String,
60    /// Source of the rule
61    pub source: RuleSource,
62    /// When the rule was created
63    pub created_at: DateTime<Utc>,
64    /// When the rule was last updated
65    pub updated_at: DateTime<Utc>,
66    /// Version number of the rule
67    pub version: u32,
68    /// Confidence score (0.0 to 1.0)
69    pub confidence: f32,
70    /// Number of times this rule has been applied
71    pub usage_count: u64,
72    /// Success rate (0.0 to 1.0)
73    pub success_rate: f32,
74    /// Additional metadata
75    pub metadata: serde_json::Value,
76}
77
78impl Rule {
79    /// Create a new rule
80    pub fn new(
81        scope: RuleScope,
82        pattern: String,
83        action: String,
84        source: RuleSource,
85    ) -> Self {
86        Self {
87            id: Uuid::new_v4().to_string(),
88            scope,
89            pattern,
90            action,
91            source,
92            created_at: Utc::now(),
93            updated_at: Utc::now(),
94            version: 1,
95            confidence: 0.5,
96            usage_count: 0,
97            success_rate: 0.0,
98            metadata: serde_json::json!({}),
99        }
100    }
101}
102
103/// Context in which a decision was made
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct DecisionContext {
106    /// Path to the project
107    pub project_path: PathBuf,
108    /// Path to the file being edited
109    pub file_path: PathBuf,
110    /// Line number in the file
111    pub line_number: u32,
112    /// Type of agent making the decision
113    pub agent_type: String,
114}
115
116/// A captured user decision
117#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct Decision {
119    /// Unique identifier for the decision
120    pub id: String,
121    /// When the decision was made
122    pub timestamp: DateTime<Utc>,
123    /// Context in which the decision was made
124    pub context: DecisionContext,
125    /// Type of decision (e.g., "code_generation", "refactoring")
126    pub decision_type: String,
127    /// Input that led to the decision
128    pub input: serde_json::Value,
129    /// Output of the decision
130    pub output: serde_json::Value,
131    /// Additional metadata
132    pub metadata: serde_json::Value,
133}
134
135impl Decision {
136    /// Create a new decision
137    pub fn new(
138        context: DecisionContext,
139        decision_type: String,
140        input: serde_json::Value,
141        output: serde_json::Value,
142    ) -> Self {
143        Self {
144            id: Uuid::new_v4().to_string(),
145            timestamp: Utc::now(),
146            context,
147            decision_type,
148            input,
149            output,
150            metadata: serde_json::json!({}),
151        }
152    }
153}
154
155/// Example of a pattern
156#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct PatternExample {
158    /// Input that produced this example
159    pub input: serde_json::Value,
160    /// Output of this example
161    pub output: serde_json::Value,
162    /// Context of this example
163    pub context: serde_json::Value,
164}
165
166/// A learned pattern extracted from repeated decisions
167#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct LearnedPattern {
169    /// Unique identifier for the pattern
170    pub id: String,
171    /// Type of pattern (e.g., "code_generation", "refactoring")
172    pub pattern_type: String,
173    /// Human-readable description
174    pub description: String,
175    /// Examples of this pattern
176    pub examples: Vec<PatternExample>,
177    /// Confidence score (0.0 to 1.0)
178    pub confidence: f32,
179    /// Number of times this pattern has been observed
180    pub occurrences: usize,
181    /// When the pattern was first identified
182    pub created_at: DateTime<Utc>,
183    /// When the pattern was last observed
184    pub last_seen: DateTime<Utc>,
185}
186
187impl LearnedPattern {
188    /// Create a new pattern
189    pub fn new(pattern_type: String, description: String) -> Self {
190        Self {
191            id: Uuid::new_v4().to_string(),
192            pattern_type,
193            description,
194            examples: Vec::new(),
195            confidence: 0.0,
196            occurrences: 0,
197            created_at: Utc::now(),
198            last_seen: Utc::now(),
199        }
200    }
201}
202
203/// Learning system configuration
204#[derive(Debug, Clone, Serialize, Deserialize)]
205pub struct LearningConfig {
206    /// Scope for learning (global, project, or session)
207    pub scope: RuleScope,
208    /// Whether learning is enabled
209    pub enabled: bool,
210    /// Whether approval is required for new rules
211    pub approval_required: bool,
212    /// Whether to automatically promote rules
213    pub auto_promote: bool,
214    /// How many days to retain rules
215    pub retention_days: u32,
216    /// Maximum number of rules to store
217    pub max_rules: usize,
218}
219
220impl Default for LearningConfig {
221    fn default() -> Self {
222        Self {
223            scope: RuleScope::Global,
224            enabled: true,
225            approval_required: false,
226            auto_promote: false,
227            retention_days: 365,
228            max_rules: 10000,
229        }
230    }
231}
232
233impl LearningConfig {
234    /// Create a new configuration with default values
235    pub fn new(scope: RuleScope) -> Self {
236        Self {
237            scope,
238            ..Default::default()
239        }
240    }
241
242    /// Validate the configuration
243    pub fn validate(&self) -> crate::error::Result<()> {
244        if self.retention_days == 0 {
245            return Err(crate::error::LearningError::ConfigurationError(
246                "retention_days must be greater than 0".to_string(),
247            ));
248        }
249
250        if self.max_rules == 0 {
251            return Err(crate::error::LearningError::ConfigurationError(
252                "max_rules must be greater than 0".to_string(),
253            ));
254        }
255
256        Ok(())
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `DecisionContext` pointing at a fixed sample file so the
    /// decision tests do not each repeat the same struct literal.
    fn sample_context(line_number: u32, agent_type: &str) -> DecisionContext {
        DecisionContext {
            project_path: PathBuf::from("/project"),
            file_path: PathBuf::from("/project/src/main.rs"),
            line_number,
            agent_type: agent_type.to_string(),
        }
    }

    #[test]
    fn test_rule_creation() {
        let rule = Rule::new(
            RuleScope::Global,
            "pattern".to_string(),
            "action".to_string(),
            RuleSource::Learned,
        );

        assert!(!rule.id.is_empty());
        assert_eq!(rule.scope, RuleScope::Global);
        assert_eq!(rule.pattern, "pattern");
        assert_eq!(rule.action, "action");
        assert_eq!(rule.source, RuleSource::Learned);
        assert_eq!(rule.version, 1);
        assert_eq!(rule.confidence, 0.5);
        assert_eq!(rule.usage_count, 0);
        assert_eq!(rule.success_rate, 0.0);
        assert_eq!(rule.metadata, serde_json::json!({}));
    }

    #[test]
    fn test_rule_ids_are_unique() {
        let make = || {
            Rule::new(
                RuleScope::Session,
                "pattern".to_string(),
                "action".to_string(),
                RuleSource::Manual,
            )
        };

        // Identical inputs must still yield distinct identifiers.
        assert_ne!(make().id, make().id);
    }

    #[test]
    fn test_rule_serialization() {
        let rule = Rule::new(
            RuleScope::Project,
            "test_pattern".to_string(),
            "test_action".to_string(),
            RuleSource::Manual,
        );

        let json = serde_json::to_string(&rule).expect("Failed to serialize");
        let deserialized: Rule = serde_json::from_str(&json).expect("Failed to deserialize");

        assert_eq!(rule.id, deserialized.id);
        assert_eq!(rule.scope, deserialized.scope);
        assert_eq!(rule.pattern, deserialized.pattern);
        assert_eq!(rule.action, deserialized.action);
        assert_eq!(rule.source, deserialized.source);
        assert_eq!(rule.version, deserialized.version);
        assert_eq!(rule.metadata, deserialized.metadata);
    }

    #[test]
    fn test_decision_creation() {
        let decision = Decision::new(
            sample_context(42, "code_generator"),
            "code_generation".to_string(),
            serde_json::json!({"input": "test"}),
            serde_json::json!({"output": "result"}),
        );

        assert!(!decision.id.is_empty());
        assert_eq!(decision.decision_type, "code_generation");
        assert_eq!(decision.context.line_number, 42);
        assert_eq!(decision.context.agent_type, "code_generator");
        assert_eq!(decision.input, serde_json::json!({"input": "test"}));
        assert_eq!(decision.output, serde_json::json!({"output": "result"}));
        assert_eq!(decision.metadata, serde_json::json!({}));
    }

    #[test]
    fn test_decision_serialization() {
        let decision = Decision::new(
            sample_context(10, "test_agent"),
            "test_type".to_string(),
            serde_json::json!({}),
            serde_json::json!({}),
        );

        let json = serde_json::to_string(&decision).expect("Failed to serialize");
        let deserialized: Decision = serde_json::from_str(&json).expect("Failed to deserialize");

        assert_eq!(decision.id, deserialized.id);
        assert_eq!(decision.decision_type, deserialized.decision_type);
        assert_eq!(decision.context.file_path, deserialized.context.file_path);
        assert_eq!(decision.context.line_number, deserialized.context.line_number);
    }

    #[test]
    fn test_pattern_creation() {
        let pattern = LearnedPattern::new(
            "code_generation".to_string(),
            "Test pattern".to_string(),
        );

        assert!(!pattern.id.is_empty());
        assert_eq!(pattern.pattern_type, "code_generation");
        assert_eq!(pattern.description, "Test pattern");
        assert!(pattern.examples.is_empty());
        assert_eq!(pattern.occurrences, 0);
        assert_eq!(pattern.confidence, 0.0);
    }

    #[test]
    fn test_pattern_serialization() {
        let pattern = LearnedPattern::new(
            "refactoring".to_string(),
            "Refactoring pattern".to_string(),
        );

        let json = serde_json::to_string(&pattern).expect("Failed to serialize");
        let deserialized: LearnedPattern =
            serde_json::from_str(&json).expect("Failed to deserialize");

        assert_eq!(pattern.id, deserialized.id);
        assert_eq!(pattern.pattern_type, deserialized.pattern_type);
        assert_eq!(pattern.description, deserialized.description);
        assert_eq!(pattern.occurrences, deserialized.occurrences);
    }

    #[test]
    fn test_learning_config_default() {
        let config = LearningConfig::default();

        assert_eq!(config.scope, RuleScope::Global);
        assert!(config.enabled);
        assert!(!config.approval_required);
        assert!(!config.auto_promote);
        assert_eq!(config.retention_days, 365);
        assert_eq!(config.max_rules, 10000);
    }

    #[test]
    fn test_learning_config_new() {
        // `new` overrides only the scope; everything else stays at the default.
        let config = LearningConfig::new(RuleScope::Session);
        let defaults = LearningConfig::default();

        assert_eq!(config.scope, RuleScope::Session);
        assert_eq!(config.enabled, defaults.enabled);
        assert_eq!(config.approval_required, defaults.approval_required);
        assert_eq!(config.auto_promote, defaults.auto_promote);
        assert_eq!(config.retention_days, defaults.retention_days);
        assert_eq!(config.max_rules, defaults.max_rules);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_learning_config_validation() {
        let mut config = LearningConfig::default();
        assert!(config.validate().is_ok());

        config.retention_days = 0;
        assert!(config.validate().is_err());

        config.retention_days = 365;
        config.max_rules = 0;
        assert!(config.validate().is_err());

        // Both limits invalid at once must still be rejected.
        config.retention_days = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_rule_scope_display() {
        assert_eq!(RuleScope::Global.to_string(), "global");
        assert_eq!(RuleScope::Project.to_string(), "project");
        assert_eq!(RuleScope::Session.to_string(), "session");
    }

    #[test]
    fn test_rule_source_display() {
        assert_eq!(RuleSource::Learned.to_string(), "learned");
        assert_eq!(RuleSource::Manual.to_string(), "manual");
        assert_eq!(RuleSource::Promoted.to_string(), "promoted");
    }
}