Skip to main content

openai_agents/session/
sqlite.rs

1//! SQLite session implementation
2
3use async_trait::async_trait;
4use serde_json::Value;
5use sqlx::{Row, sqlite::SqlitePool};
6
7use crate::error::{AgentError, Result};
8
9use super::Session;
10
11/// SQLite-based session storage
12pub struct SqliteSession {
13    session_id: String,
14    pool: SqlitePool,
15}
16
17impl SqliteSession {
18    /// Create a new SQLite session
19    pub async fn new(session_id: impl Into<String>, db_path: impl AsRef<str>) -> Result<Self> {
20        let pool = SqlitePool::connect(db_path.as_ref())
21            .await
22            .map_err(|e| AgentError::SessionError(e.to_string()))?;
23
24        // Create table if it doesn't exist
25        sqlx::query(
26            r#"
27            CREATE TABLE IF NOT EXISTS sessions (
28                session_id TEXT NOT NULL,
29                item_index INTEGER NOT NULL,
30                item_data TEXT NOT NULL,
31                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
32                PRIMARY KEY (session_id, item_index)
33            )
34            "#,
35        )
36        .execute(&pool)
37        .await?;
38
39        Ok(Self {
40            session_id: session_id.into(),
41            pool,
42        })
43    }
44}
45
46#[async_trait]
47impl Session for SqliteSession {
48    async fn get_items(&self, limit: Option<usize>) -> Result<Vec<Value>> {
49        let query = if let Some(limit) = limit {
50            format!(
51                "SELECT item_data FROM sessions WHERE session_id = ? ORDER BY item_index DESC LIMIT {}",
52                limit
53            )
54        } else {
55            "SELECT item_data FROM sessions WHERE session_id = ? ORDER BY item_index ASC"
56                .to_string()
57        };
58
59        let rows = sqlx::query(&query)
60            .bind(&self.session_id)
61            .fetch_all(&self.pool)
62            .await?;
63
64        let mut items = Vec::new();
65        for row in rows {
66            let data: String = row.try_get("item_data")?;
67            let value: Value = serde_json::from_str(&data)?;
68            items.push(value);
69        }
70
71        Ok(items)
72    }
73
74    async fn add_items(&self, items: Vec<Value>) -> Result<()> {
75        for item in items {
76            let data = serde_json::to_string(&item)?;
77
78            // Get the next index
79            let next_index: i64 = sqlx::query_scalar(
80                "SELECT COALESCE(MAX(item_index), -1) + 1 FROM sessions WHERE session_id = ?",
81            )
82            .bind(&self.session_id)
83            .fetch_one(&self.pool)
84            .await?;
85
86            sqlx::query(
87                "INSERT INTO sessions (session_id, item_index, item_data) VALUES (?, ?, ?)",
88            )
89            .bind(&self.session_id)
90            .bind(next_index)
91            .bind(data)
92            .execute(&self.pool)
93            .await?;
94        }
95
96        Ok(())
97    }
98
99    async fn pop_item(&self) -> Result<Option<Value>> {
100        let row: Option<(i64, String)> = sqlx::query_as(
101            "SELECT item_index, item_data FROM sessions WHERE session_id = ? ORDER BY item_index DESC LIMIT 1"
102        )
103        .bind(&self.session_id)
104        .fetch_optional(&self.pool)
105        .await?;
106
107        if let Some((index, data)) = row {
108            sqlx::query("DELETE FROM sessions WHERE session_id = ? AND item_index = ?")
109                .bind(&self.session_id)
110                .bind(index)
111                .execute(&self.pool)
112                .await?;
113
114            let value: Value = serde_json::from_str(&data)?;
115            Ok(Some(value))
116        } else {
117            Ok(None)
118        }
119    }
120
121    async fn clear_session(&self) -> Result<()> {
122        sqlx::query("DELETE FROM sessions WHERE session_id = ?")
123            .bind(&self.session_id)
124            .execute(&self.pool)
125            .await?;
126
127        Ok(())
128    }
129}