Skip to main content

rustant_tools/
flashcards.rs

1//! Flashcard tool — spaced repetition learning with SM-2 algorithm.
2
3use async_trait::async_trait;
4use chrono::{DateTime, Duration as ChronoDuration, Utc};
5use rustant_core::error::ToolError;
6use rustant_core::types::{RiskLevel, ToolOutput};
7use serde::{Deserialize, Serialize};
8use serde_json::{Value, json};
9use std::path::PathBuf;
10
11use crate::registry::Tool;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14struct Flashcard {
15    id: usize,
16    deck: String,
17    front: String,
18    back: String,
19    // SM-2 fields
20    easiness_factor: f64,
21    interval_days: f64,
22    repetitions: u32,
23    next_review: DateTime<Utc>,
24    created_at: DateTime<Utc>,
25}
26
27impl Flashcard {
28    fn new(id: usize, deck: &str, front: &str, back: &str) -> Self {
29        Self {
30            id,
31            deck: deck.to_string(),
32            front: front.to_string(),
33            back: back.to_string(),
34            easiness_factor: 2.5,
35            interval_days: 1.0,
36            repetitions: 0,
37            next_review: Utc::now(),
38            created_at: Utc::now(),
39        }
40    }
41
42    /// Apply SM-2 algorithm with quality 0-5.
43    fn review(&mut self, quality: u32) {
44        let q = quality.min(5) as f64;
45        // Update easiness factor
46        self.easiness_factor =
47            (self.easiness_factor + 0.1 - (5.0 - q) * (0.08 + (5.0 - q) * 0.02)).max(1.3);
48
49        if quality < 3 {
50            // Failed — reset
51            self.interval_days = 1.0;
52            self.repetitions = 0;
53        } else {
54            self.repetitions += 1;
55            match self.repetitions {
56                1 => self.interval_days = 1.0,
57                2 => self.interval_days = 6.0,
58                _ => self.interval_days *= self.easiness_factor,
59            }
60        }
61        self.next_review =
62            Utc::now() + ChronoDuration::seconds((self.interval_days * 86400.0) as i64);
63    }
64}
65
66#[derive(Debug, Default, Serialize, Deserialize)]
67struct FlashcardState {
68    cards: Vec<Flashcard>,
69    next_id: usize,
70}
71
72pub struct FlashcardsTool {
73    workspace: PathBuf,
74}
75
76impl FlashcardsTool {
77    pub fn new(workspace: PathBuf) -> Self {
78        Self { workspace }
79    }
80
81    fn state_path(&self) -> PathBuf {
82        self.workspace
83            .join(".rustant")
84            .join("flashcards")
85            .join("cards.json")
86    }
87
88    fn load_state(&self) -> FlashcardState {
89        let path = self.state_path();
90        if path.exists() {
91            std::fs::read_to_string(&path)
92                .ok()
93                .and_then(|s| serde_json::from_str(&s).ok())
94                .unwrap_or_default()
95        } else {
96            FlashcardState {
97                cards: Vec::new(),
98                next_id: 1,
99            }
100        }
101    }
102
103    fn save_state(&self, state: &FlashcardState) -> Result<(), ToolError> {
104        let path = self.state_path();
105        if let Some(parent) = path.parent() {
106            std::fs::create_dir_all(parent).map_err(|e| ToolError::ExecutionFailed {
107                name: "flashcards".to_string(),
108                message: e.to_string(),
109            })?;
110        }
111        let json = serde_json::to_string_pretty(state).map_err(|e| ToolError::ExecutionFailed {
112            name: "flashcards".to_string(),
113            message: e.to_string(),
114        })?;
115        let tmp = path.with_extension("json.tmp");
116        std::fs::write(&tmp, &json).map_err(|e| ToolError::ExecutionFailed {
117            name: "flashcards".to_string(),
118            message: e.to_string(),
119        })?;
120        std::fs::rename(&tmp, &path).map_err(|e| ToolError::ExecutionFailed {
121            name: "flashcards".to_string(),
122            message: e.to_string(),
123        })?;
124        Ok(())
125    }
126}
127
128#[async_trait]
129impl Tool for FlashcardsTool {
130    fn name(&self) -> &str {
131        "flashcards"
132    }
133    fn description(&self) -> &str {
134        "Spaced repetition flashcards with SM-2 algorithm. Actions: add_card, study, answer, list_decks, stats."
135    }
136    fn parameters_schema(&self) -> Value {
137        json!({
138            "type": "object",
139            "properties": {
140                "action": { "type": "string", "enum": ["add_card", "study", "answer", "list_decks", "stats"] },
141                "deck": { "type": "string", "description": "Deck name" },
142                "front": { "type": "string", "description": "Card front (question)" },
143                "back": { "type": "string", "description": "Card back (answer)" },
144                "card_id": { "type": "integer", "description": "Card ID (for answer)" },
145                "quality": { "type": "integer", "description": "Answer quality 0-5 (0=forgot, 3=correct with difficulty, 5=easy)" }
146            },
147            "required": ["action"]
148        })
149    }
150    fn risk_level(&self) -> RiskLevel {
151        RiskLevel::Write
152    }
153
154    async fn execute(&self, args: Value) -> Result<ToolOutput, ToolError> {
155        let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("");
156        let mut state = self.load_state();
157
158        match action {
159            "add_card" => {
160                let deck = args
161                    .get("deck")
162                    .and_then(|v| v.as_str())
163                    .unwrap_or("default");
164                let front = args.get("front").and_then(|v| v.as_str()).unwrap_or("");
165                let back = args.get("back").and_then(|v| v.as_str()).unwrap_or("");
166                if front.is_empty() || back.is_empty() {
167                    return Ok(ToolOutput::text(
168                        "Provide both front and back for the card.",
169                    ));
170                }
171                let id = state.next_id;
172                state.next_id += 1;
173                state.cards.push(Flashcard::new(id, deck, front, back));
174                self.save_state(&state)?;
175                Ok(ToolOutput::text(format!(
176                    "Added card #{} to deck '{}'.",
177                    id, deck
178                )))
179            }
180            "study" => {
181                let deck_filter = args.get("deck").and_then(|v| v.as_str());
182                let now = Utc::now();
183                let due: Vec<&Flashcard> = state
184                    .cards
185                    .iter()
186                    .filter(|c| c.next_review <= now)
187                    .filter(|c| deck_filter.map(|d| c.deck == d).unwrap_or(true))
188                    .take(1)
189                    .collect();
190                if due.is_empty() {
191                    return Ok(ToolOutput::text("No cards due for review. Great job!"));
192                }
193                let card = due[0];
194                Ok(ToolOutput::text(format!(
195                    "Card #{} [{}]\n\nQ: {}\n\n(Use answer action with card_id={} and quality=0-5 to respond)",
196                    card.id, card.deck, card.front, card.id
197                )))
198            }
199            "answer" => {
200                let card_id = args.get("card_id").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
201                let quality = args.get("quality").and_then(|v| v.as_u64()).unwrap_or(3) as u32;
202                if let Some(card) = state.cards.iter_mut().find(|c| c.id == card_id) {
203                    let answer = card.back.clone();
204                    card.review(quality);
205                    let next = card.next_review.format("%Y-%m-%d");
206                    self.save_state(&state)?;
207                    let feedback = match quality {
208                        0..=2 => "Keep studying! Card will appear again soon.",
209                        3 => "Correct! Next review in a day.",
210                        4 => "Good! Interval extended.",
211                        _ => "Perfect! Long interval set.",
212                    };
213                    Ok(ToolOutput::text(format!(
214                        "A: {}\n\n{}\nNext review: {}",
215                        answer, feedback, next
216                    )))
217                } else {
218                    Ok(ToolOutput::text(format!("Card #{} not found.", card_id)))
219                }
220            }
221            "list_decks" => {
222                let mut decks: std::collections::HashMap<&str, (usize, usize)> =
223                    std::collections::HashMap::new();
224                let now = Utc::now();
225                for card in &state.cards {
226                    let entry = decks.entry(&card.deck).or_insert((0, 0));
227                    entry.0 += 1;
228                    if card.next_review <= now {
229                        entry.1 += 1;
230                    }
231                }
232                if decks.is_empty() {
233                    return Ok(ToolOutput::text("No flashcard decks yet."));
234                }
235                let lines: Vec<String> = decks
236                    .iter()
237                    .map(|(d, (total, due))| format!("  {} — {} cards ({} due)", d, total, due))
238                    .collect();
239                Ok(ToolOutput::text(format!("Decks:\n{}", lines.join("\n"))))
240            }
241            "stats" => {
242                let total = state.cards.len();
243                let now = Utc::now();
244                let due = state.cards.iter().filter(|c| c.next_review <= now).count();
245                let avg_ef: f64 = if total > 0 {
246                    state.cards.iter().map(|c| c.easiness_factor).sum::<f64>() / total as f64
247                } else {
248                    0.0
249                };
250                Ok(ToolOutput::text(format!(
251                    "Flashcard stats:\n  Total cards: {}\n  Due now: {}\n  Average EF: {:.2}",
252                    total, due, avg_ef
253                )))
254            }
255            _ => Ok(ToolOutput::text(format!("Unknown action: {}.", action))),
256        }
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Build a tool rooted in a fresh temporary workspace.
    ///
    /// The `TempDir` is returned so the directory lives as long as the test.
    fn temp_tool() -> (TempDir, FlashcardsTool) {
        let dir = TempDir::new().unwrap();
        let workspace = dir.path().canonicalize().unwrap();
        (dir, FlashcardsTool::new(workspace))
    }

    #[test]
    fn test_sm2_easy_increases_interval() {
        let mut card = Flashcard::new(1, "test", "Q", "A");
        assert_eq!(card.interval_days, 1.0);
        card.review(5); // Perfect
        assert_eq!(card.repetitions, 1);
        card.review(5); // Perfect again
        assert_eq!(card.repetitions, 2);
        assert!(card.interval_days >= 6.0);
        card.review(5); // Third time
        assert!(card.interval_days > 6.0); // Should grow
    }

    #[test]
    fn test_sm2_hard_resets() {
        let mut card = Flashcard::new(1, "test", "Q", "A");
        card.review(5);
        card.review(5);
        assert!(card.interval_days >= 6.0);
        card.review(1); // Failed
        assert_eq!(card.repetitions, 0);
        assert_eq!(card.interval_days, 1.0);
    }

    #[test]
    fn test_sm2_easiness_floor() {
        let mut card = Flashcard::new(1, "test", "Q", "A");
        for _ in 0..20 {
            card.review(0); // Worst quality
        }
        assert!(card.easiness_factor >= 1.3);
    }

    #[test]
    fn test_sm2_quality_above_five_is_clamped() {
        let mut perfect = Flashcard::new(1, "test", "Q", "A");
        let mut oversized = Flashcard::new(2, "test", "Q", "A");
        perfect.review(5);
        oversized.review(42);
        assert_eq!(perfect.easiness_factor, oversized.easiness_factor);
        assert_eq!(perfect.repetitions, oversized.repetitions);
        assert_eq!(perfect.interval_days, oversized.interval_days);
    }

    #[tokio::test]
    async fn test_flashcards_add_study() {
        let (_dir, tool) = temp_tool();
        tool.execute(json!({"action": "add_card", "deck": "rust", "front": "What is ownership?", "back": "A memory management system"})).await.unwrap();
        let result = tool.execute(json!({"action": "study"})).await.unwrap();
        assert!(result.content.contains("ownership"));
    }

    #[tokio::test]
    async fn test_flashcards_add_requires_front_and_back() {
        let (_dir, tool) = temp_tool();
        let result = tool
            .execute(json!({"action": "add_card", "front": "Q only"}))
            .await
            .unwrap();
        assert!(result.content.contains("Provide both front and back"));
        let stats = tool.execute(json!({"action": "stats"})).await.unwrap();
        assert!(stats.content.contains("Total cards: 0"));
    }

    #[tokio::test]
    async fn test_flashcards_answer_reschedules() {
        let (_dir, tool) = temp_tool();
        tool.execute(json!({"action": "add_card", "front": "Q1", "back": "A1"}))
            .await
            .unwrap();
        let result = tool
            .execute(json!({"action": "answer", "card_id": 1, "quality": 5}))
            .await
            .unwrap();
        assert!(result.content.contains("A: A1"));
        // A successful first review schedules the card at least a day out.
        let result = tool.execute(json!({"action": "study"})).await.unwrap();
        assert!(result.content.contains("No cards due"));
    }

    #[tokio::test]
    async fn test_flashcards_answer_unknown_card() {
        let (_dir, tool) = temp_tool();
        let result = tool
            .execute(json!({"action": "answer", "card_id": 42, "quality": 4}))
            .await
            .unwrap();
        assert!(result.content.contains("Card #42 not found"));
    }

    #[tokio::test]
    async fn test_flashcards_list_decks_and_stats() {
        let (_dir, tool) = temp_tool();
        for (deck, front) in [("rust", "Q1"), ("rust", "Q2"), ("go", "Q3")] {
            tool.execute(json!({"action": "add_card", "deck": deck, "front": front, "back": "A"}))
                .await
                .unwrap();
        }
        let decks = tool.execute(json!({"action": "list_decks"})).await.unwrap();
        assert!(decks.content.contains("rust — 2 cards (2 due)"));
        assert!(decks.content.contains("go — 1 cards (1 due)"));
        let stats = tool.execute(json!({"action": "stats"})).await.unwrap();
        assert!(stats.content.contains("Total cards: 3"));
        assert!(stats.content.contains("Due now: 3"));
    }

    #[test]
    fn test_flashcards_schema() {
        let dir = TempDir::new().unwrap();
        let tool = FlashcardsTool::new(dir.path().to_path_buf());
        assert_eq!(tool.name(), "flashcards");
        let schema = tool.parameters_schema();
        assert_eq!(schema["required"][0], "action");
    }
}