1use crate::Result;
34use ankit::AnkiClient;
35use serde::Serialize;
36use std::collections::HashMap;
37
/// Describes which notes to inspect and which of their fields must be blank
/// for a note to count as an enrichment candidate.
#[derive(Debug, Clone)]
pub struct EnrichQuery {
    /// Search string handed to the client's note search.
    pub search: String,
    /// Names of the fields to check; a note qualifies when at least one of
    /// them is blank or missing.
    pub empty_fields: Vec<String>,
}
46
/// A note that matched an [`EnrichQuery`] and has at least one requested
/// field blank.
#[derive(Debug, Clone, Serialize)]
pub struct EnrichCandidate {
    /// Identifier of the note.
    pub note_id: i64,
    /// Name of the note's model, as reported by the client.
    pub model_name: String,
    /// Current field values, keyed by field name.
    pub fields: HashMap<String, String>,
    /// The requested field names that were blank (or absent) on this note.
    pub empty_fields: Vec<String>,
    /// Tags currently attached to the note.
    pub tags: Vec<String>,
}
61
/// Outcome of a batch of note updates.
#[derive(Debug, Clone, Default, Serialize)]
pub struct EnrichReport {
    /// Number of notes updated successfully.
    pub updated: usize,
    /// Number of notes whose update failed.
    pub failed: usize,
    /// One entry per failed note, with the error text.
    pub failures: Vec<EnrichFailure>,
}
72
/// A single failed note update.
#[derive(Debug, Clone, Serialize)]
pub struct EnrichFailure {
    /// Identifier of the note that could not be updated.
    pub note_id: i64,
    /// String rendering of the error returned by the update call.
    pub error: String,
}
81
/// Finds notes with blank fields and writes field updates back, using a
/// borrowed [`AnkiClient`] for every request.
#[derive(Debug)]
pub struct EnrichEngine<'a> {
    /// Client used for all note operations; borrowed for the engine's lifetime.
    client: &'a AnkiClient,
}
87
88impl<'a> EnrichEngine<'a> {
89 pub(crate) fn new(client: &'a AnkiClient) -> Self {
90 Self { client }
91 }
92
93 pub async fn find_candidates(&self, query: &EnrichQuery) -> Result<Vec<EnrichCandidate>> {
122 let note_ids = self.client.notes().find(&query.search).await?;
123
124 if note_ids.is_empty() {
125 return Ok(Vec::new());
126 }
127
128 let note_infos = self.client.notes().info(¬e_ids).await?;
129 let mut candidates = Vec::new();
130
131 for info in note_infos {
132 let empty: Vec<String> = query
134 .empty_fields
135 .iter()
136 .filter(|field_name| {
137 info.fields
138 .get(*field_name)
139 .map(|f| f.value.trim().is_empty())
140 .unwrap_or(true) })
142 .cloned()
143 .collect();
144
145 if !empty.is_empty() {
146 let fields: HashMap<String, String> =
148 info.fields.into_iter().map(|(k, v)| (k, v.value)).collect();
149
150 candidates.push(EnrichCandidate {
151 note_id: info.note_id,
152 model_name: info.model_name,
153 fields,
154 empty_fields: empty,
155 tags: info.tags,
156 });
157 }
158 }
159
160 Ok(candidates)
161 }
162
163 pub async fn update_note(&self, note_id: i64, fields: &HashMap<String, String>) -> Result<()> {
186 self.client.notes().update_fields(note_id, fields).await?;
187 Ok(())
188 }
189
190 pub async fn update_notes(
215 &self,
216 updates: &[(i64, HashMap<String, String>)],
217 ) -> Result<EnrichReport> {
218 let mut report = EnrichReport::default();
219
220 for (note_id, fields) in updates {
221 match self.client.notes().update_fields(*note_id, fields).await {
222 Ok(_) => report.updated += 1,
223 Err(e) => {
224 report.failed += 1;
225 report.failures.push(EnrichFailure {
226 note_id: *note_id,
227 error: e.to_string(),
228 });
229 }
230 }
231 }
232
233 Ok(report)
234 }
235
236 pub async fn tag_enriched(&self, note_ids: &[i64], tag: &str) -> Result<()> {
245 if !note_ids.is_empty() {
246 self.client.notes().add_tags(note_ids, tag).await?;
247 }
248 Ok(())
249 }
250
251 pub async fn pipeline(&self, query: &EnrichQuery) -> Result<EnrichmentPipeline> {
295 let candidates = self.find_candidates(query).await?;
296 Ok(EnrichmentPipeline::new(candidates))
297 }
298}
299
/// A set of enrichment candidates together with the field updates staged for
/// them, to be written in one pass by `commit`.
#[derive(Debug, Clone)]
pub struct EnrichmentPipeline {
    /// Candidates this pipeline was created with.
    candidates: Vec<EnrichCandidate>,
    /// Staged updates: note id to (field name to new value).
    updates: HashMap<i64, HashMap<String, String>>,
}
309
310impl EnrichmentPipeline {
311 pub fn new(candidates: Vec<EnrichCandidate>) -> Self {
313 Self {
314 candidates,
315 updates: HashMap::new(),
316 }
317 }
318
319 pub fn candidates(&self) -> &[EnrichCandidate] {
321 &self.candidates
322 }
323
324 pub fn len(&self) -> usize {
326 self.candidates.len()
327 }
328
329 pub fn is_empty(&self) -> bool {
331 self.candidates.is_empty()
332 }
333
334 pub fn by_missing_field(&self) -> HashMap<String, Vec<&EnrichCandidate>> {
344 let mut groups: HashMap<String, Vec<&EnrichCandidate>> = HashMap::new();
345
346 for candidate in &self.candidates {
347 for field in &candidate.empty_fields {
348 groups.entry(field.clone()).or_default().push(candidate);
349 }
350 }
351
352 groups
353 }
354
355 pub fn by_model(&self) -> HashMap<String, Vec<&EnrichCandidate>> {
359 let mut groups: HashMap<String, Vec<&EnrichCandidate>> = HashMap::new();
360
361 for candidate in &self.candidates {
362 groups
363 .entry(candidate.model_name.clone())
364 .or_default()
365 .push(candidate);
366 }
367
368 groups
369 }
370
371 pub fn update(&mut self, note_id: i64, fields: HashMap<String, String>) {
381 self.updates.entry(note_id).or_default().extend(fields);
382 }
383
384 pub fn pending_updates(&self) -> usize {
386 self.updates.len()
387 }
388
389 pub fn pending_candidates(&self) -> Vec<&EnrichCandidate> {
391 self.candidates
392 .iter()
393 .filter(|c| !self.updates.contains_key(&c.note_id))
394 .collect()
395 }
396
397 pub async fn commit(&self, engine: &crate::Engine) -> Result<EnrichPipelineReport> {
407 let skipped = self
409 .candidates
410 .iter()
411 .filter(|c| !self.updates.contains_key(&c.note_id))
412 .count();
413
414 let mut updated = 0;
415 let mut failed = Vec::new();
416
417 for (note_id, fields) in &self.updates {
419 match engine.enrich().update_note(*note_id, fields).await {
420 Ok(_) => updated += 1,
421 Err(e) => {
422 failed.push((*note_id, e.to_string()));
423 }
424 }
425 }
426
427 Ok(EnrichPipelineReport {
428 updated,
429 failed,
430 skipped,
431 })
432 }
433
434 pub async fn commit_and_tag(
441 &self,
442 engine: &crate::Engine,
443 tag: &str,
444 ) -> Result<EnrichPipelineReport> {
445 let report = self.commit(engine).await?;
446
447 if report.updated > 0 {
449 let updated_ids: Vec<i64> = self
450 .updates
451 .keys()
452 .filter(|id| !report.failed.iter().any(|(fid, _)| fid == *id))
453 .copied()
454 .collect();
455
456 if !updated_ids.is_empty() {
457 engine.enrich().tag_enriched(&updated_ids, tag).await?;
458 }
459 }
460
461 Ok(report)
462 }
463}
464
/// Outcome of committing an [`EnrichmentPipeline`].
#[derive(Debug, Clone, Default, Serialize)]
pub struct EnrichPipelineReport {
    /// Number of notes updated successfully.
    pub updated: usize,
    /// `(note id, error text)` for every update that failed.
    pub failed: Vec<(i64, String)>,
    /// Number of candidates that had no update staged.
    pub skipped: usize,
}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a candidate with no field values and no tags; only the parts
    /// the pipeline tests look at are configurable.
    fn candidate(note_id: i64, model_name: &str, empty_fields: &[&str]) -> EnrichCandidate {
        EnrichCandidate {
            note_id,
            model_name: model_name.to_string(),
            fields: HashMap::new(),
            empty_fields: empty_fields.iter().copied().map(String::from).collect(),
            tags: Vec::new(),
        }
    }

    /// Builds an owned field map from borrowed `(name, value)` pairs.
    fn field_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(name, value)| (name.to_owned(), value.to_owned()))
            .collect()
    }

    #[test]
    fn test_enrich_query_construction() {
        let query = EnrichQuery {
            search: "deck:Test".to_string(),
            empty_fields: vec!["Example".to_string(), "Audio".to_string()],
        };

        assert_eq!(query.search, "deck:Test");
        assert_eq!(query.empty_fields.len(), 2);
        assert!(query.empty_fields.contains(&"Example".to_string()));
    }

    #[test]
    fn test_enrich_candidate_construction() {
        let candidate = EnrichCandidate {
            note_id: 12345,
            model_name: "Basic".to_string(),
            fields: field_map(&[("Front", "Hello"), ("Back", "World")]),
            empty_fields: vec!["Example".to_string()],
            tags: vec!["tag1".to_string()],
        };

        assert_eq!(candidate.note_id, 12345);
        assert_eq!(candidate.model_name, "Basic");
        assert_eq!(candidate.fields.len(), 2);
        assert_eq!(candidate.empty_fields.len(), 1);
        assert_eq!(candidate.tags.len(), 1);
    }

    #[test]
    fn test_enrich_candidate_serialization() {
        let candidate = candidate(100, "Vocab", &["Definition"]);

        let json = serde_json::to_string(&candidate).unwrap();
        assert!(json.contains("\"note_id\":100"));
        assert!(json.contains("\"model_name\":\"Vocab\""));
    }

    #[test]
    fn test_enrich_report_default() {
        let report = EnrichReport::default();
        assert_eq!(report.updated, 0);
        assert_eq!(report.failed, 0);
        assert!(report.failures.is_empty());
    }

    #[test]
    fn test_enrich_report_construction() {
        let failure = EnrichFailure {
            note_id: 999,
            error: "Not found".to_string(),
        };

        let report = EnrichReport {
            updated: 5,
            failed: 1,
            failures: vec![failure],
        };

        assert_eq!(report.updated, 5);
        assert_eq!(report.failed, 1);
        assert_eq!(report.failures.len(), 1);
        assert_eq!(report.failures[0].note_id, 999);
    }

    #[test]
    fn test_enrich_failure_construction() {
        let failure = EnrichFailure {
            note_id: 12345,
            error: "Field not found".to_string(),
        };

        assert_eq!(failure.note_id, 12345);
        assert_eq!(failure.error, "Field not found");
    }

    #[test]
    fn test_enrich_failure_serialization() {
        let failure = EnrichFailure {
            note_id: 456,
            error: "Connection error".to_string(),
        };

        let json = serde_json::to_string(&failure).unwrap();
        assert!(json.contains("\"note_id\":456"));
        assert!(json.contains("\"error\":\"Connection error\""));
    }

    #[test]
    fn test_enrichment_pipeline_new_empty() {
        let pipeline = EnrichmentPipeline::new(vec![]);
        assert!(pipeline.is_empty());
        assert_eq!(pipeline.len(), 0);
        assert_eq!(pipeline.pending_updates(), 0);
    }

    #[test]
    fn test_enrichment_pipeline_with_candidates() {
        let pipeline = EnrichmentPipeline::new(vec![candidate(1, "Basic", &["Back"])]);
        assert!(!pipeline.is_empty());
        assert_eq!(pipeline.len(), 1);
        assert_eq!(pipeline.candidates().len(), 1);
    }

    #[test]
    fn test_enrichment_pipeline_update() {
        let mut pipeline = EnrichmentPipeline::new(vec![candidate(100, "Basic", &["Back"])]);
        assert_eq!(pipeline.pending_updates(), 0);

        pipeline.update(100, field_map(&[("Back", "Answer")]));

        assert_eq!(pipeline.pending_updates(), 1);
    }

    #[test]
    fn test_enrichment_pipeline_update_merge() {
        let mut pipeline = EnrichmentPipeline::new(vec![]);

        pipeline.update(100, field_map(&[("Field1", "Value1")]));
        pipeline.update(100, field_map(&[("Field2", "Value2")]));

        // Still a single note, but it must now carry both staged fields.
        assert_eq!(pipeline.pending_updates(), 1);
        let merged = pipeline.updates.get(&100).expect("update staged for note 100");
        assert_eq!(merged.len(), 2);
        assert_eq!(merged.get("Field1").map(String::as_str), Some("Value1"));
        assert_eq!(merged.get("Field2").map(String::as_str), Some("Value2"));
    }

    #[test]
    fn test_enrichment_pipeline_update_overwrites_field() {
        let mut pipeline = EnrichmentPipeline::new(vec![]);

        pipeline.update(100, field_map(&[("Field1", "Old")]));
        pipeline.update(100, field_map(&[("Field1", "New")]));

        // The later value for the same field name replaces the earlier one.
        let staged = pipeline.updates.get(&100).expect("update staged for note 100");
        assert_eq!(staged.len(), 1);
        assert_eq!(staged.get("Field1").map(String::as_str), Some("New"));
    }

    #[test]
    fn test_enrichment_pipeline_pending_candidates() {
        let mut pipeline = EnrichmentPipeline::new(vec![
            candidate(1, "Basic", &["Back"]),
            candidate(2, "Basic", &["Back"]),
        ]);
        assert_eq!(pipeline.pending_candidates().len(), 2);

        pipeline.update(1, field_map(&[("Back", "Answer")]));

        let pending = pipeline.pending_candidates();
        assert_eq!(pending.len(), 1);
        assert_eq!(pending[0].note_id, 2);
    }

    #[test]
    fn test_enrichment_pipeline_by_missing_field() {
        let pipeline = EnrichmentPipeline::new(vec![
            candidate(1, "Basic", &["Field1"]),
            candidate(2, "Basic", &["Field1", "Field2"]),
        ]);
        let by_field = pipeline.by_missing_field();

        assert_eq!(by_field.get("Field1").unwrap().len(), 2);
        assert_eq!(by_field.get("Field2").unwrap().len(), 1);
    }

    #[test]
    fn test_enrichment_pipeline_by_model() {
        let pipeline = EnrichmentPipeline::new(vec![
            candidate(1, "Basic", &[]),
            candidate(2, "Cloze", &[]),
            candidate(3, "Basic", &[]),
        ]);
        let by_model = pipeline.by_model();

        assert_eq!(by_model.get("Basic").unwrap().len(), 2);
        assert_eq!(by_model.get("Cloze").unwrap().len(), 1);
    }

    #[test]
    fn test_enrich_pipeline_report_default() {
        let report = EnrichPipelineReport::default();
        assert_eq!(report.updated, 0);
        assert!(report.failed.is_empty());
        assert_eq!(report.skipped, 0);
    }

    #[test]
    fn test_enrich_pipeline_report_construction() {
        let report = EnrichPipelineReport {
            updated: 10,
            failed: vec![(100, "Error".to_string())],
            skipped: 5,
        };

        assert_eq!(report.updated, 10);
        assert_eq!(report.failed.len(), 1);
        assert_eq!(report.skipped, 5);
    }

    #[test]
    fn test_enrich_pipeline_report_serialization() {
        let report = EnrichPipelineReport {
            updated: 3,
            failed: vec![],
            skipped: 2,
        };

        let json = serde_json::to_string(&report).unwrap();
        assert!(json.contains("\"updated\":3"));
        assert!(json.contains("\"skipped\":2"));
    }
}