scitadel_db/sqlite/
assessments.rs1use rusqlite::params;
2use scitadel_core::error::CoreError;
3use scitadel_core::models::{Assessment, AssessmentId, PaperId, QuestionId};
4use scitadel_core::ports::AssessmentRepository;
5
6use super::Database;
7use crate::error::DbError;
8
9pub struct SqliteAssessmentRepository {
10 db: Database,
11}
12
13impl SqliteAssessmentRepository {
14 pub fn new(db: Database) -> Self {
15 Self { db }
16 }
17}
18
19fn row_to_assessment(row: &rusqlite::Row) -> rusqlite::Result<Assessment> {
20 let id: String = row.get("id")?;
21 let paper_id: String = row.get("paper_id")?;
22 let question_id: String = row.get("question_id")?;
23 let created_at: String = row.get("created_at")?;
24
25 Ok(Assessment {
26 id: AssessmentId::from(id),
27 paper_id: PaperId::from(paper_id),
28 question_id: QuestionId::from(question_id),
29 score: row.get("score")?,
30 reasoning: row.get("reasoning")?,
31 model: row.get("model")?,
32 prompt: row.get("prompt")?,
33 temperature: row.get("temperature")?,
34 assessor: row.get("assessor")?,
35 created_at: super::parse_rfc3339_or_now(&created_at),
36 })
37}
38
39impl AssessmentRepository for SqliteAssessmentRepository {
40 fn save(&self, assessment: &Assessment) -> Result<(), CoreError> {
41 let conn = self.db.conn()?;
42 conn.execute(
43 "INSERT OR REPLACE INTO assessments
44 (id, paper_id, question_id, score, reasoning, model,
45 prompt, temperature, assessor, created_at)
46 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10)",
47 params![
48 assessment.id.as_str(),
49 assessment.paper_id.as_str(),
50 assessment.question_id.as_str(),
51 assessment.score,
52 assessment.reasoning,
53 assessment.model,
54 assessment.prompt,
55 assessment.temperature,
56 assessment.assessor,
57 assessment.created_at.to_rfc3339(),
58 ],
59 )
60 .map_err(DbError::Sqlite)?;
61 Ok(())
62 }
63
64 fn get_for_paper(
65 &self,
66 paper_id: &str,
67 question_id: Option<&str>,
68 ) -> Result<Vec<Assessment>, CoreError> {
69 let conn = self.db.conn()?;
70 let (sql, query_params): (&str, Vec<Box<dyn rusqlite::types::ToSql>>) =
71 if let Some(qid) = question_id {
72 (
73 "SELECT * FROM assessments WHERE paper_id = ?1 AND question_id = ?2",
74 vec![Box::new(paper_id.to_string()), Box::new(qid.to_string())],
75 )
76 } else {
77 (
78 "SELECT * FROM assessments WHERE paper_id = ?1",
79 vec![Box::new(paper_id.to_string())],
80 )
81 };
82 let mut stmt = conn.prepare(sql).map_err(DbError::Sqlite)?;
83 let params: Vec<&dyn rusqlite::types::ToSql> =
84 query_params.iter().map(|b| b.as_ref()).collect();
85 let assessments = stmt
86 .query_map(params.as_slice(), row_to_assessment)
87 .map_err(DbError::Sqlite)?
88 .filter_map(Result::ok)
89 .collect();
90 Ok(assessments)
91 }
92
93 fn get_for_question(&self, question_id: &str) -> Result<Vec<Assessment>, CoreError> {
94 let conn = self.db.conn()?;
95 let mut stmt = conn
96 .prepare("SELECT * FROM assessments WHERE question_id = ?1")
97 .map_err(DbError::Sqlite)?;
98 let assessments = stmt
99 .query_map(params![question_id], row_to_assessment)
100 .map_err(DbError::Sqlite)?
101 .filter_map(Result::ok)
102 .collect();
103 Ok(assessments)
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::{Database, SqlitePaperRepository, SqliteQuestionRepository};
    use scitadel_core::models::{Paper, ResearchQuestion};
    use scitadel_core::ports::{PaperRepository, QuestionRepository};

    /// Saves one assessment and reads it back by paper and question.
    #[test]
    fn test_assessment_crud() {
        // Fresh in-memory database with migrations applied.
        let database = Database::open_in_memory().unwrap();
        database.migrate().unwrap();

        let papers = SqlitePaperRepository::new(database.clone());
        let questions = SqliteQuestionRepository::new(database.clone());
        let assessments = SqliteAssessmentRepository::new(database);

        // The paper and question are stored before the assessment that refers
        // to them (presumably required by foreign keys — verify against schema).
        let paper = Paper::new("Test Paper");
        papers.save(&paper).unwrap();

        let question = ResearchQuestion::new("Test question");
        questions.save_question(&question).unwrap();

        let stored = Assessment::new(paper.id.clone(), question.id.clone(), 0.85);
        assessments.save(&stored).unwrap();

        let fetched = assessments
            .get_for_paper(paper.id.as_str(), Some(question.id.as_str()))
            .unwrap();
        assert_eq!(fetched.len(), 1);

        let delta = (fetched[0].score - 0.85).abs();
        assert!(delta < f64::EPSILON);
    }
}