openai_agents/session/
sqlite.rs1use async_trait::async_trait;
4use serde_json::Value;
5use sqlx::{Row, sqlite::SqlitePool};
6
7use crate::error::{AgentError, Result};
8
9use super::Session;
10
11pub struct SqliteSession {
13 session_id: String,
14 pool: SqlitePool,
15}
16
17impl SqliteSession {
18 pub async fn new(session_id: impl Into<String>, db_path: impl AsRef<str>) -> Result<Self> {
20 let pool = SqlitePool::connect(db_path.as_ref())
21 .await
22 .map_err(|e| AgentError::SessionError(e.to_string()))?;
23
24 sqlx::query(
26 r#"
27 CREATE TABLE IF NOT EXISTS sessions (
28 session_id TEXT NOT NULL,
29 item_index INTEGER NOT NULL,
30 item_data TEXT NOT NULL,
31 created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
32 PRIMARY KEY (session_id, item_index)
33 )
34 "#,
35 )
36 .execute(&pool)
37 .await?;
38
39 Ok(Self {
40 session_id: session_id.into(),
41 pool,
42 })
43 }
44}
45
46#[async_trait]
47impl Session for SqliteSession {
48 async fn get_items(&self, limit: Option<usize>) -> Result<Vec<Value>> {
49 let query = if let Some(limit) = limit {
50 format!(
51 "SELECT item_data FROM sessions WHERE session_id = ? ORDER BY item_index DESC LIMIT {}",
52 limit
53 )
54 } else {
55 "SELECT item_data FROM sessions WHERE session_id = ? ORDER BY item_index ASC"
56 .to_string()
57 };
58
59 let rows = sqlx::query(&query)
60 .bind(&self.session_id)
61 .fetch_all(&self.pool)
62 .await?;
63
64 let mut items = Vec::new();
65 for row in rows {
66 let data: String = row.try_get("item_data")?;
67 let value: Value = serde_json::from_str(&data)?;
68 items.push(value);
69 }
70
71 Ok(items)
72 }
73
74 async fn add_items(&self, items: Vec<Value>) -> Result<()> {
75 for item in items {
76 let data = serde_json::to_string(&item)?;
77
78 let next_index: i64 = sqlx::query_scalar(
80 "SELECT COALESCE(MAX(item_index), -1) + 1 FROM sessions WHERE session_id = ?",
81 )
82 .bind(&self.session_id)
83 .fetch_one(&self.pool)
84 .await?;
85
86 sqlx::query(
87 "INSERT INTO sessions (session_id, item_index, item_data) VALUES (?, ?, ?)",
88 )
89 .bind(&self.session_id)
90 .bind(next_index)
91 .bind(data)
92 .execute(&self.pool)
93 .await?;
94 }
95
96 Ok(())
97 }
98
99 async fn pop_item(&self) -> Result<Option<Value>> {
100 let row: Option<(i64, String)> = sqlx::query_as(
101 "SELECT item_index, item_data FROM sessions WHERE session_id = ? ORDER BY item_index DESC LIMIT 1"
102 )
103 .bind(&self.session_id)
104 .fetch_optional(&self.pool)
105 .await?;
106
107 if let Some((index, data)) = row {
108 sqlx::query("DELETE FROM sessions WHERE session_id = ? AND item_index = ?")
109 .bind(&self.session_id)
110 .bind(index)
111 .execute(&self.pool)
112 .await?;
113
114 let value: Value = serde_json::from_str(&data)?;
115 Ok(Some(value))
116 } else {
117 Ok(None)
118 }
119 }
120
121 async fn clear_session(&self) -> Result<()> {
122 sqlx::query("DELETE FROM sessions WHERE session_id = ?")
123 .bind(&self.session_id)
124 .execute(&self.pool)
125 .await?;
126
127 Ok(())
128 }
129}