Skip to main content

ankit_engine/
enrich.rs

1//! Note enrichment operations.
2//!
3//! This module provides workflows for finding notes with empty fields
4//! and updating them with new content.
5//!
6//! # Example
7//!
8//! ```no_run
9//! use ankit_engine::Engine;
10//! use ankit_engine::enrich::EnrichQuery;
11//!
12//! # async fn example() -> ankit_engine::Result<()> {
13//! let engine = Engine::new();
14//!
15//! // Find notes missing the "Example" field
16//! let query = EnrichQuery {
17//!     search: "deck:Japanese".to_string(),
18//!     empty_fields: vec!["Example".to_string()],
19//! };
20//!
21//! let candidates = engine.enrich().find_candidates(&query).await?;
22//! println!("Found {} notes needing enrichment", candidates.len());
23//!
24//! // Update a note with enriched content
25//! use std::collections::HashMap;
26//! let mut updates = HashMap::new();
27//! updates.insert("Example".to_string(), "New example sentence".to_string());
28//! engine.enrich().update_note(candidates[0].note_id, &updates).await?;
29//! # Ok(())
30//! # }
31//! ```
32
33use crate::Result;
34use ankit::AnkiClient;
35use serde::Serialize;
36use std::collections::HashMap;
37
/// Query parameters for finding notes to enrich.
///
/// Consumed by `EnrichEngine::find_candidates` and `EnrichEngine::pipeline`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EnrichQuery {
    /// Anki search query used to select notes (e.g. `deck:Japanese`).
    pub search: String,
    /// Field names to check. A note qualifies if any one of these is blank
    /// (empty or whitespace-only) or is not present on the note at all.
    pub empty_fields: Vec<String>,
}
46
47/// A note that is a candidate for enrichment.
48#[derive(Debug, Clone, Serialize)]
49pub struct EnrichCandidate {
50    /// The note ID.
51    pub note_id: i64,
52    /// The model (note type) name.
53    pub model_name: String,
54    /// Current field values.
55    pub fields: HashMap<String, String>,
56    /// Fields that are empty and need enrichment.
57    pub empty_fields: Vec<String>,
58    /// Current tags on the note.
59    pub tags: Vec<String>,
60}
61
62/// Report from a batch enrichment operation.
63#[derive(Debug, Clone, Default, Serialize)]
64pub struct EnrichReport {
65    /// Number of notes successfully updated.
66    pub updated: usize,
67    /// Number of notes that failed to update.
68    pub failed: usize,
69    /// Details about failed updates.
70    pub failures: Vec<EnrichFailure>,
71}
72
73/// Details about a failed enrichment.
74#[derive(Debug, Clone, Serialize)]
75pub struct EnrichFailure {
76    /// The note ID that failed.
77    pub note_id: i64,
78    /// The error message.
79    pub error: String,
80}
81
82/// Enrichment workflow engine.
83#[derive(Debug)]
84pub struct EnrichEngine<'a> {
85    client: &'a AnkiClient,
86}
87
88impl<'a> EnrichEngine<'a> {
89    pub(crate) fn new(client: &'a AnkiClient) -> Self {
90        Self { client }
91    }
92
93    /// Find notes that have empty fields matching the query criteria.
94    ///
95    /// Returns a list of candidates with information about which fields need enrichment.
96    ///
97    /// # Arguments
98    ///
99    /// * `query` - Query parameters specifying search filter and fields to check
100    ///
101    /// # Example
102    ///
103    /// ```no_run
104    /// # use ankit_engine::Engine;
105    /// # use ankit_engine::enrich::EnrichQuery;
106    /// # async fn example() -> ankit_engine::Result<()> {
107    /// let engine = Engine::new();
108    ///
109    /// let query = EnrichQuery {
110    ///     search: "deck:\"My Deck\" note:Basic".to_string(),
111    ///     empty_fields: vec!["Example".to_string(), "Pronunciation".to_string()],
112    /// };
113    ///
114    /// let candidates = engine.enrich().find_candidates(&query).await?;
115    /// for candidate in &candidates {
116    ///     println!("Note {} needs: {:?}", candidate.note_id, candidate.empty_fields);
117    /// }
118    /// # Ok(())
119    /// # }
120    /// ```
121    pub async fn find_candidates(&self, query: &EnrichQuery) -> Result<Vec<EnrichCandidate>> {
122        let note_ids = self.client.notes().find(&query.search).await?;
123
124        if note_ids.is_empty() {
125            return Ok(Vec::new());
126        }
127
128        let note_infos = self.client.notes().info(&note_ids).await?;
129        let mut candidates = Vec::new();
130
131        for info in note_infos {
132            // Check which specified fields are empty
133            let empty: Vec<String> = query
134                .empty_fields
135                .iter()
136                .filter(|field_name| {
137                    info.fields
138                        .get(*field_name)
139                        .map(|f| f.value.trim().is_empty())
140                        .unwrap_or(true) // Field doesn't exist = empty
141                })
142                .cloned()
143                .collect();
144
145            if !empty.is_empty() {
146                // Convert fields to simple HashMap
147                let fields: HashMap<String, String> =
148                    info.fields.into_iter().map(|(k, v)| (k, v.value)).collect();
149
150                candidates.push(EnrichCandidate {
151                    note_id: info.note_id,
152                    model_name: info.model_name,
153                    fields,
154                    empty_fields: empty,
155                    tags: info.tags,
156                });
157            }
158        }
159
160        Ok(candidates)
161    }
162
163    /// Update a single note with new field values.
164    ///
165    /// # Arguments
166    ///
167    /// * `note_id` - The note to update
168    /// * `fields` - Map of field name to new value
169    ///
170    /// # Example
171    ///
172    /// ```no_run
173    /// # use ankit_engine::Engine;
174    /// # use std::collections::HashMap;
175    /// # async fn example() -> ankit_engine::Result<()> {
176    /// let engine = Engine::new();
177    ///
178    /// let mut fields = HashMap::new();
179    /// fields.insert("Example".to_string(), "This is an example sentence.".to_string());
180    ///
181    /// engine.enrich().update_note(12345, &fields).await?;
182    /// # Ok(())
183    /// # }
184    /// ```
185    pub async fn update_note(&self, note_id: i64, fields: &HashMap<String, String>) -> Result<()> {
186        self.client.notes().update_fields(note_id, fields).await?;
187        Ok(())
188    }
189
190    /// Update multiple notes with new field values.
191    ///
192    /// # Arguments
193    ///
194    /// * `updates` - List of (note_id, fields) pairs to update
195    ///
196    /// # Example
197    ///
198    /// ```no_run
199    /// # use ankit_engine::Engine;
200    /// # use std::collections::HashMap;
201    /// # async fn example() -> ankit_engine::Result<()> {
202    /// let engine = Engine::new();
203    ///
204    /// let updates: Vec<(i64, HashMap<String, String>)> = vec![
205    ///     (12345, [("Example".to_string(), "Example 1".to_string())].into_iter().collect()),
206    ///     (12346, [("Example".to_string(), "Example 2".to_string())].into_iter().collect()),
207    /// ];
208    ///
209    /// let report = engine.enrich().update_notes(&updates).await?;
210    /// println!("Updated: {}, Failed: {}", report.updated, report.failed);
211    /// # Ok(())
212    /// # }
213    /// ```
214    pub async fn update_notes(
215        &self,
216        updates: &[(i64, HashMap<String, String>)],
217    ) -> Result<EnrichReport> {
218        let mut report = EnrichReport::default();
219
220        for (note_id, fields) in updates {
221            match self.client.notes().update_fields(*note_id, fields).await {
222                Ok(_) => report.updated += 1,
223                Err(e) => {
224                    report.failed += 1;
225                    report.failures.push(EnrichFailure {
226                        note_id: *note_id,
227                        error: e.to_string(),
228                    });
229                }
230            }
231        }
232
233        Ok(report)
234    }
235
236    /// Add a tag to notes after enrichment.
237    ///
238    /// Useful for marking notes as processed.
239    ///
240    /// # Arguments
241    ///
242    /// * `note_ids` - Notes to tag
243    /// * `tag` - Tag to add
244    pub async fn tag_enriched(&self, note_ids: &[i64], tag: &str) -> Result<()> {
245        if !note_ids.is_empty() {
246            self.client.notes().add_tags(note_ids, tag).await?;
247        }
248        Ok(())
249    }
250
251    /// Create an enrichment pipeline for batch processing.
252    ///
253    /// The pipeline finds candidates and provides helpers for grouping,
254    /// updating, and committing changes.
255    ///
256    /// # Arguments
257    ///
258    /// * `query` - Query parameters specifying search filter and fields to check
259    ///
260    /// # Example
261    ///
262    /// ```no_run
263    /// # use ankit_engine::Engine;
264    /// # use ankit_engine::enrich::EnrichQuery;
265    /// # async fn example() -> ankit_engine::Result<()> {
266    /// let engine = Engine::new();
267    ///
268    /// let query = EnrichQuery {
269    ///     search: "deck:Japanese".to_string(),
270    ///     empty_fields: vec!["Example".to_string(), "Pronunciation".to_string()],
271    /// };
272    ///
273    /// let mut pipeline = engine.enrich().pipeline(&query).await?;
274    ///
275    /// // Process by missing field for efficient batching
276    /// for (field, candidates) in pipeline.by_missing_field() {
277    ///     println!("Field '{}' needs {} notes enriched", field, candidates.len());
278    /// }
279    ///
280    /// // Buffer updates - collect IDs first to avoid borrow issues
281    /// let note_ids: Vec<i64> = pipeline.candidates().iter().map(|c| c.note_id).collect();
282    /// for note_id in note_ids {
283    ///     pipeline.update(note_id, [
284    ///         ("Example".to_string(), "Generated example".to_string())
285    ///     ].into_iter().collect());
286    /// }
287    ///
288    /// // Commit all updates
289    /// let report = pipeline.commit(&engine).await?;
290    /// println!("Updated {} notes", report.updated);
291    /// # Ok(())
292    /// # }
293    /// ```
294    pub async fn pipeline(&self, query: &EnrichQuery) -> Result<EnrichmentPipeline> {
295        let candidates = self.find_candidates(query).await?;
296        Ok(EnrichmentPipeline::new(candidates))
297    }
298}
299
300/// A pipeline for batch enrichment operations.
301///
302/// Provides helpers for grouping candidates by missing field,
303/// buffering updates, and committing them in a single operation.
304#[derive(Debug, Clone)]
305pub struct EnrichmentPipeline {
306    candidates: Vec<EnrichCandidate>,
307    updates: HashMap<i64, HashMap<String, String>>,
308}
309
310impl EnrichmentPipeline {
311    /// Create a new pipeline with the given candidates.
312    pub fn new(candidates: Vec<EnrichCandidate>) -> Self {
313        Self {
314            candidates,
315            updates: HashMap::new(),
316        }
317    }
318
319    /// Get the candidates for enrichment.
320    pub fn candidates(&self) -> &[EnrichCandidate] {
321        &self.candidates
322    }
323
324    /// Get the number of candidates.
325    pub fn len(&self) -> usize {
326        self.candidates.len()
327    }
328
329    /// Check if there are no candidates.
330    pub fn is_empty(&self) -> bool {
331        self.candidates.is_empty()
332    }
333
334    /// Get candidates grouped by which field they're missing.
335    ///
336    /// This is useful for batch processing where you want to generate
337    /// content for all notes missing a specific field at once.
338    ///
339    /// # Returns
340    ///
341    /// A map from field name to the candidates that are missing that field.
342    /// A candidate may appear in multiple groups if it's missing multiple fields.
343    pub fn by_missing_field(&self) -> HashMap<String, Vec<&EnrichCandidate>> {
344        let mut groups: HashMap<String, Vec<&EnrichCandidate>> = HashMap::new();
345
346        for candidate in &self.candidates {
347            for field in &candidate.empty_fields {
348                groups.entry(field.clone()).or_default().push(candidate);
349            }
350        }
351
352        groups
353    }
354
355    /// Get candidates grouped by model name.
356    ///
357    /// Useful when different models need different enrichment strategies.
358    pub fn by_model(&self) -> HashMap<String, Vec<&EnrichCandidate>> {
359        let mut groups: HashMap<String, Vec<&EnrichCandidate>> = HashMap::new();
360
361        for candidate in &self.candidates {
362            groups
363                .entry(candidate.model_name.clone())
364                .or_default()
365                .push(candidate);
366        }
367
368        groups
369    }
370
371    /// Buffer an update for a note.
372    ///
373    /// Updates are not applied until `commit()` is called.
374    /// Multiple updates to the same note will be merged.
375    ///
376    /// # Arguments
377    ///
378    /// * `note_id` - The note to update
379    /// * `fields` - Field values to set
380    pub fn update(&mut self, note_id: i64, fields: HashMap<String, String>) {
381        self.updates.entry(note_id).or_default().extend(fields);
382    }
383
384    /// Get the number of buffered updates.
385    pub fn pending_updates(&self) -> usize {
386        self.updates.len()
387    }
388
389    /// Get candidates that haven't been updated yet.
390    pub fn pending_candidates(&self) -> Vec<&EnrichCandidate> {
391        self.candidates
392            .iter()
393            .filter(|c| !self.updates.contains_key(&c.note_id))
394            .collect()
395    }
396
397    /// Commit all buffered updates.
398    ///
399    /// # Arguments
400    ///
401    /// * `engine` - The engine to use for committing
402    ///
403    /// # Returns
404    ///
405    /// A report with counts of updated, failed, and skipped notes.
406    pub async fn commit(&self, engine: &crate::Engine) -> Result<EnrichPipelineReport> {
407        // Count skipped (candidates without updates)
408        let skipped = self
409            .candidates
410            .iter()
411            .filter(|c| !self.updates.contains_key(&c.note_id))
412            .count();
413
414        let mut updated = 0;
415        let mut failed = Vec::new();
416
417        // Apply updates
418        for (note_id, fields) in &self.updates {
419            match engine.enrich().update_note(*note_id, fields).await {
420                Ok(_) => updated += 1,
421                Err(e) => {
422                    failed.push((*note_id, e.to_string()));
423                }
424            }
425        }
426
427        Ok(EnrichPipelineReport {
428            updated,
429            failed,
430            skipped,
431        })
432    }
433
434    /// Commit all buffered updates and tag the updated notes.
435    ///
436    /// # Arguments
437    ///
438    /// * `engine` - The engine to use for committing
439    /// * `tag` - Tag to add to successfully updated notes
440    pub async fn commit_and_tag(
441        &self,
442        engine: &crate::Engine,
443        tag: &str,
444    ) -> Result<EnrichPipelineReport> {
445        let report = self.commit(engine).await?;
446
447        // Tag successfully updated notes
448        if report.updated > 0 {
449            let updated_ids: Vec<i64> = self
450                .updates
451                .keys()
452                .filter(|id| !report.failed.iter().any(|(fid, _)| fid == *id))
453                .copied()
454                .collect();
455
456            if !updated_ids.is_empty() {
457                engine.enrich().tag_enriched(&updated_ids, tag).await?;
458            }
459        }
460
461        Ok(report)
462    }
463}
464
465/// Report from an enrichment pipeline commit.
466#[derive(Debug, Clone, Default, Serialize)]
467pub struct EnrichPipelineReport {
468    /// Number of notes successfully updated.
469    pub updated: usize,
470    /// Notes that failed to update (note_id, error message).
471    pub failed: Vec<(i64, String)>,
472    /// Number of candidates that were not updated (no update buffered).
473    pub skipped: usize,
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a candidate with no field values and no tags.
    fn candidate(note_id: i64, model_name: &str, empty_fields: &[&str]) -> EnrichCandidate {
        EnrichCandidate {
            note_id,
            model_name: model_name.to_string(),
            fields: HashMap::new(),
            empty_fields: empty_fields.iter().map(|f| f.to_string()).collect(),
            tags: vec![],
        }
    }

    /// Build an owned field map from string pairs.
    fn field_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    #[test]
    fn test_enrich_query_construction() {
        let query = EnrichQuery {
            search: "deck:Test".to_string(),
            empty_fields: vec!["Example".to_string(), "Audio".to_string()],
        };

        assert_eq!(query.search, "deck:Test");
        assert_eq!(query.empty_fields.len(), 2);
        assert!(query.empty_fields.contains(&"Example".to_string()));
    }

    #[test]
    fn test_enrich_candidate_construction() {
        let candidate = EnrichCandidate {
            note_id: 12345,
            model_name: "Basic".to_string(),
            fields: field_map(&[("Front", "Hello"), ("Back", "World")]),
            empty_fields: vec!["Example".to_string()],
            tags: vec!["tag1".to_string()],
        };

        assert_eq!(candidate.note_id, 12345);
        assert_eq!(candidate.model_name, "Basic");
        assert_eq!(candidate.fields.len(), 2);
        assert_eq!(candidate.empty_fields.len(), 1);
        assert_eq!(candidate.tags.len(), 1);
    }

    #[test]
    fn test_enrich_candidate_serialization() {
        let candidate = candidate(100, "Vocab", &["Definition"]);

        let json = serde_json::to_string(&candidate).unwrap();
        assert!(json.contains("\"note_id\":100"));
        assert!(json.contains("\"model_name\":\"Vocab\""));
    }

    #[test]
    fn test_enrich_report_default() {
        let report = EnrichReport::default();
        assert_eq!(report.updated, 0);
        assert_eq!(report.failed, 0);
        assert!(report.failures.is_empty());
    }

    #[test]
    fn test_enrich_report_construction() {
        let report = EnrichReport {
            updated: 5,
            failed: 1,
            failures: vec![EnrichFailure {
                note_id: 999,
                error: "Not found".to_string(),
            }],
        };

        assert_eq!(report.updated, 5);
        assert_eq!(report.failed, 1);
        assert_eq!(report.failures.len(), 1);
        assert_eq!(report.failures[0].note_id, 999);
    }

    #[test]
    fn test_enrich_failure_construction() {
        let failure = EnrichFailure {
            note_id: 12345,
            error: "Field not found".to_string(),
        };

        assert_eq!(failure.note_id, 12345);
        assert_eq!(failure.error, "Field not found");
    }

    #[test]
    fn test_enrich_failure_serialization() {
        let failure = EnrichFailure {
            note_id: 456,
            error: "Connection error".to_string(),
        };

        let json = serde_json::to_string(&failure).unwrap();
        assert!(json.contains("\"note_id\":456"));
        assert!(json.contains("\"error\":\"Connection error\""));
    }

    #[test]
    fn test_enrichment_pipeline_new_empty() {
        let pipeline = EnrichmentPipeline::new(vec![]);
        assert!(pipeline.is_empty());
        assert_eq!(pipeline.len(), 0);
        assert_eq!(pipeline.pending_updates(), 0);
    }

    #[test]
    fn test_enrichment_pipeline_with_candidates() {
        let pipeline = EnrichmentPipeline::new(vec![candidate(1, "Basic", &["Back"])]);
        assert!(!pipeline.is_empty());
        assert_eq!(pipeline.len(), 1);
        assert_eq!(pipeline.candidates().len(), 1);
    }

    #[test]
    fn test_enrichment_pipeline_update() {
        let mut pipeline = EnrichmentPipeline::new(vec![candidate(100, "Basic", &["Back"])]);
        assert_eq!(pipeline.pending_updates(), 0);

        pipeline.update(100, field_map(&[("Back", "Answer")]));

        assert_eq!(pipeline.pending_updates(), 1);
    }

    #[test]
    fn test_enrichment_pipeline_update_merge() {
        let mut pipeline = EnrichmentPipeline::new(vec![]);

        pipeline.update(100, field_map(&[("Field1", "Value1")]));
        pipeline.update(100, field_map(&[("Field2", "Value2")]));

        // Still one note, but both fields must survive the merge.
        assert_eq!(pipeline.pending_updates(), 1);
        let merged = &pipeline.updates[&100];
        assert_eq!(merged.len(), 2);
        assert_eq!(merged.get("Field1").map(String::as_str), Some("Value1"));
        assert_eq!(merged.get("Field2").map(String::as_str), Some("Value2"));
    }

    #[test]
    fn test_enrichment_pipeline_update_overwrites_same_field() {
        let mut pipeline = EnrichmentPipeline::new(vec![]);

        pipeline.update(7, field_map(&[("Back", "old")]));
        pipeline.update(7, field_map(&[("Back", "new")]));

        assert_eq!(pipeline.pending_updates(), 1);
        assert_eq!(
            pipeline.updates[&7].get("Back").map(String::as_str),
            Some("new")
        );
    }

    #[test]
    fn test_enrichment_pipeline_pending_candidates() {
        let mut pipeline = EnrichmentPipeline::new(vec![
            candidate(1, "Basic", &["Back"]),
            candidate(2, "Basic", &["Back"]),
        ]);
        assert_eq!(pipeline.pending_candidates().len(), 2);

        pipeline.update(1, field_map(&[("Back", "Answer")]));

        assert_eq!(pipeline.pending_candidates().len(), 1);
        assert_eq!(pipeline.pending_candidates()[0].note_id, 2);
    }

    #[test]
    fn test_enrichment_pipeline_by_missing_field() {
        let pipeline = EnrichmentPipeline::new(vec![
            candidate(1, "Basic", &["Field1"]),
            candidate(2, "Basic", &["Field1", "Field2"]),
        ]);
        let by_field = pipeline.by_missing_field();

        assert_eq!(by_field.get("Field1").unwrap().len(), 2);
        assert_eq!(by_field.get("Field2").unwrap().len(), 1);
        assert!(by_field.get("Field3").is_none());
    }

    #[test]
    fn test_enrichment_pipeline_by_model() {
        let pipeline = EnrichmentPipeline::new(vec![
            candidate(1, "Basic", &[]),
            candidate(2, "Cloze", &[]),
            candidate(3, "Basic", &[]),
        ]);
        let by_model = pipeline.by_model();

        assert_eq!(by_model.get("Basic").unwrap().len(), 2);
        assert_eq!(by_model.get("Cloze").unwrap().len(), 1);
    }

    #[test]
    fn test_enrichment_pipeline_groupings_empty() {
        let pipeline = EnrichmentPipeline::new(vec![]);

        assert!(pipeline.by_missing_field().is_empty());
        assert!(pipeline.by_model().is_empty());
        assert!(pipeline.pending_candidates().is_empty());
    }

    #[test]
    fn test_enrich_pipeline_report_default() {
        let report = EnrichPipelineReport::default();
        assert_eq!(report.updated, 0);
        assert!(report.failed.is_empty());
        assert_eq!(report.skipped, 0);
    }

    #[test]
    fn test_enrich_pipeline_report_construction() {
        let report = EnrichPipelineReport {
            updated: 10,
            failed: vec![(100, "Error".to_string())],
            skipped: 5,
        };

        assert_eq!(report.updated, 10);
        assert_eq!(report.failed.len(), 1);
        assert_eq!(report.skipped, 5);
    }

    #[test]
    fn test_enrich_pipeline_report_serialization() {
        let report = EnrichPipelineReport {
            updated: 3,
            failed: vec![],
            skipped: 2,
        };

        let json = serde_json::to_string(&report).unwrap();
        assert!(json.contains("\"updated\":3"));
        assert!(json.contains("\"skipped\":2"));
    }
}