Skip to main content

scitadel_db/sqlite/
assessments.rs

1use rusqlite::params;
2use scitadel_core::error::CoreError;
3use scitadel_core::models::{Assessment, AssessmentId, PaperId, QuestionId};
4use scitadel_core::ports::AssessmentRepository;
5
6use super::Database;
7use crate::error::DbError;
8
9pub struct SqliteAssessmentRepository {
10    db: Database,
11}
12
13impl SqliteAssessmentRepository {
14    pub fn new(db: Database) -> Self {
15        Self { db }
16    }
17}
18
19fn row_to_assessment(row: &rusqlite::Row) -> rusqlite::Result<Assessment> {
20    let id: String = row.get("id")?;
21    let paper_id: String = row.get("paper_id")?;
22    let question_id: String = row.get("question_id")?;
23    let created_at: String = row.get("created_at")?;
24
25    Ok(Assessment {
26        id: AssessmentId::from(id),
27        paper_id: PaperId::from(paper_id),
28        question_id: QuestionId::from(question_id),
29        score: row.get("score")?,
30        reasoning: row.get("reasoning")?,
31        model: row.get("model")?,
32        prompt: row.get("prompt")?,
33        temperature: row.get("temperature")?,
34        assessor: row.get("assessor")?,
35        created_at: super::parse_rfc3339_or_now(&created_at),
36    })
37}
38
39impl AssessmentRepository for SqliteAssessmentRepository {
40    fn save(&self, assessment: &Assessment) -> Result<(), CoreError> {
41        let conn = self.db.conn()?;
42        conn.execute(
43            "INSERT OR REPLACE INTO assessments
44                (id, paper_id, question_id, score, reasoning, model,
45                 prompt, temperature, assessor, created_at)
46             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
47            params![
48                assessment.id.as_str(),
49                assessment.paper_id.as_str(),
50                assessment.question_id.as_str(),
51                assessment.score,
52                assessment.reasoning,
53                assessment.model,
54                assessment.prompt,
55                assessment.temperature,
56                assessment.assessor,
57                assessment.created_at.to_rfc3339(),
58            ],
59        )
60        .map_err(DbError::Sqlite)?;
61        Ok(())
62    }
63
64    fn get_for_paper(
65        &self,
66        paper_id: &str,
67        question_id: Option<&str>,
68    ) -> Result<Vec<Assessment>, CoreError> {
69        let conn = self.db.conn()?;
70        let (sql, query_params): (&str, Vec<Box<dyn rusqlite::types::ToSql>>) =
71            if let Some(qid) = question_id {
72                (
73                    "SELECT * FROM assessments WHERE paper_id = ?1 AND question_id = ?2",
74                    vec![Box::new(paper_id.to_string()), Box::new(qid.to_string())],
75                )
76            } else {
77                (
78                    "SELECT * FROM assessments WHERE paper_id = ?1",
79                    vec![Box::new(paper_id.to_string())],
80                )
81            };
82        let mut stmt = conn.prepare(sql).map_err(DbError::Sqlite)?;
83        let params: Vec<&dyn rusqlite::types::ToSql> =
84            query_params.iter().map(|b| b.as_ref()).collect();
85        let assessments = stmt
86            .query_map(params.as_slice(), row_to_assessment)
87            .map_err(DbError::Sqlite)?
88            .filter_map(Result::ok)
89            .collect();
90        Ok(assessments)
91    }
92
93    fn get_for_question(&self, question_id: &str) -> Result<Vec<Assessment>, CoreError> {
94        let conn = self.db.conn()?;
95        let mut stmt = conn
96            .prepare("SELECT * FROM assessments WHERE question_id = ?1")
97            .map_err(DbError::Sqlite)?;
98        let assessments = stmt
99            .query_map(params![question_id], row_to_assessment)
100            .map_err(DbError::Sqlite)?
101            .filter_map(Result::ok)
102            .collect();
103        Ok(assessments)
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::{Database, SqlitePaperRepository, SqliteQuestionRepository};
    use scitadel_core::models::{Paper, ResearchQuestion};
    use scitadel_core::ports::{PaperRepository, QuestionRepository};

    /// Opens a migrated in-memory database and returns the paper, question
    /// and assessment repositories sharing it.
    fn setup() -> (
        SqlitePaperRepository,
        SqliteQuestionRepository,
        SqliteAssessmentRepository,
    ) {
        let db = Database::open_in_memory().unwrap();
        db.migrate().unwrap();
        (
            SqlitePaperRepository::new(db.clone()),
            SqliteQuestionRepository::new(db.clone()),
            SqliteAssessmentRepository::new(db),
        )
    }

    #[test]
    fn test_assessment_crud() {
        let (paper_repo, q_repo, a_repo) = setup();

        let paper = Paper::new("Test Paper");
        paper_repo.save(&paper).unwrap();

        let q = ResearchQuestion::new("Test question");
        q_repo.save_question(&q).unwrap();

        let assessment = Assessment::new(paper.id.clone(), q.id.clone(), 0.85);
        a_repo.save(&assessment).unwrap();

        let loaded = a_repo
            .get_for_paper(paper.id.as_str(), Some(q.id.as_str()))
            .unwrap();
        assert_eq!(loaded.len(), 1);
        assert!((loaded[0].score - 0.85).abs() < f64::EPSILON);
    }

    /// Exercises both branches of `get_for_paper` and `get_for_question`.
    #[test]
    fn test_assessment_filters_by_paper_and_question() {
        let (paper_repo, q_repo, a_repo) = setup();

        let paper = Paper::new("Test Paper");
        paper_repo.save(&paper).unwrap();

        let q1 = ResearchQuestion::new("First question");
        let q2 = ResearchQuestion::new("Second question");
        q_repo.save_question(&q1).unwrap();
        q_repo.save_question(&q2).unwrap();

        a_repo
            .save(&Assessment::new(paper.id.clone(), q1.id.clone(), 0.25))
            .unwrap();
        a_repo
            .save(&Assessment::new(paper.id.clone(), q2.id.clone(), 0.75))
            .unwrap();

        let all = a_repo.get_for_paper(paper.id.as_str(), None).unwrap();
        assert_eq!(all.len(), 2);

        let only_q1 = a_repo
            .get_for_paper(paper.id.as_str(), Some(q1.id.as_str()))
            .unwrap();
        assert_eq!(only_q1.len(), 1);
        assert!((only_q1[0].score - 0.25).abs() < f64::EPSILON);

        let only_q2 = a_repo.get_for_question(q2.id.as_str()).unwrap();
        assert_eq!(only_q2.len(), 1);
        assert!((only_q2[0].score - 0.75).abs() < f64::EPSILON);
    }

    /// Saving the same assessment twice overwrites it (`INSERT OR REPLACE`).
    #[test]
    fn test_assessment_save_replaces_existing() {
        let (paper_repo, q_repo, a_repo) = setup();

        let paper = Paper::new("Test Paper");
        paper_repo.save(&paper).unwrap();

        let q = ResearchQuestion::new("Test question");
        q_repo.save_question(&q).unwrap();

        let mut assessment = Assessment::new(paper.id.clone(), q.id.clone(), 0.1);
        a_repo.save(&assessment).unwrap();
        assessment.score = 0.9;
        a_repo.save(&assessment).unwrap();

        let loaded = a_repo.get_for_paper(paper.id.as_str(), None).unwrap();
        assert_eq!(loaded.len(), 1);
        assert!((loaded[0].score - 0.9).abs() < f64::EPSILON);
    }
}