// crabtalk-proxy 0.0.6
//
// HTTP proxy server for the crabtalk LLM API gateway.
// Documentation
use crabtalk_core::{BoxFuture, Error, KvPairs, Prefix, Storage};
use sqlx::{Row, SqlitePool, sqlite::SqlitePoolOptions};

/// [`Storage`] backend that keeps its data in a SQLite database.
///
/// All queries issued by this type go through the wrapped connection pool.
pub struct SqliteStorage {
    /// Connection pool every statement is executed against.
    pool: SqlitePool,
}

impl SqliteStorage {
    pub async fn open(url: &str) -> Result<Self, Error> {
        let pool = SqlitePoolOptions::new()
            .connect(url)
            .await
            .map_err(|e| Error::Internal(format!("sqlite open: {e}")))?;

        sqlx::query("CREATE TABLE IF NOT EXISTS kv (key BLOB PRIMARY KEY, value BLOB NOT NULL)")
            .execute(&pool)
            .await
            .map_err(|e| Error::Internal(format!("sqlite init: {e}")))?;

        sqlx::query(
            "CREATE TABLE IF NOT EXISTS counters (key BLOB PRIMARY KEY, value INTEGER NOT NULL DEFAULT 0)",
        )
        .execute(&pool)
        .await
        .map_err(|e| Error::Internal(format!("sqlite init: {e}")))?;

        Ok(Self { pool })
    }
}

impl Storage for SqliteStorage {
    fn get(&self, key: &[u8]) -> BoxFuture<'_, Result<Option<Vec<u8>>, Error>> {
        let key = key.to_vec();
        Box::pin(async move {
            let row = sqlx::query("SELECT value FROM kv WHERE key = ?")
                .bind(&key)
                .fetch_optional(&self.pool)
                .await
                .map_err(|e| Error::Internal(e.to_string()))?;
            Ok(row.map(|r| r.get::<Vec<u8>, _>("value")))
        })
    }

    fn set(&self, key: &[u8], value: Vec<u8>) -> BoxFuture<'_, Result<(), Error>> {
        let key = key.to_vec();
        Box::pin(async move {
            sqlx::query(
                "INSERT INTO kv (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = excluded.value",
            )
            .bind(&key)
            .bind(&value)
            .execute(&self.pool)
            .await
            .map_err(|e| Error::Internal(e.to_string()))?;
            Ok(())
        })
    }

    fn increment(&self, key: &[u8], delta: i64) -> BoxFuture<'_, Result<i64, Error>> {
        let key = key.to_vec();
        Box::pin(async move {
            let row = sqlx::query(
                "INSERT INTO counters (key, value) VALUES (?, ?) ON CONFLICT(key) DO UPDATE SET value = value + excluded.value RETURNING value",
            )
            .bind(&key)
            .bind(delta)
            .fetch_one(&self.pool)
            .await
            .map_err(|e| Error::Internal(e.to_string()))?;
            Ok(row.get::<i64, _>("value"))
        })
    }

    fn list(&self, prefix: &Prefix) -> BoxFuture<'_, Result<KvPairs, Error>> {
        let prefix_vec = prefix.to_vec();
        let mut upper = prefix_vec.clone();
        if let Some(last) = upper.last_mut() {
            *last = last.wrapping_add(1);
        }

        Box::pin(async move {
            let mut result = Vec::new();

            let kv_rows = sqlx::query("SELECT key, value FROM kv WHERE key >= ? AND key < ?")
                .bind(&prefix_vec)
                .bind(&upper)
                .fetch_all(&self.pool)
                .await
                .map_err(|e| Error::Internal(e.to_string()))?;

            for row in &kv_rows {
                result.push((row.get::<Vec<u8>, _>("key"), row.get::<Vec<u8>, _>("value")));
            }

            let counter_rows =
                sqlx::query("SELECT key, value FROM counters WHERE key >= ? AND key < ?")
                    .bind(&prefix_vec)
                    .bind(&upper)
                    .fetch_all(&self.pool)
                    .await
                    .map_err(|e| Error::Internal(e.to_string()))?;

            let seen: std::collections::HashSet<Vec<u8>> =
                result.iter().map(|(k, _)| k.clone()).collect();
            for row in &counter_rows {
                let k: Vec<u8> = row.get("key");
                let v: i64 = row.get("value");
                if !seen.contains(&k) {
                    result.push((k, v.to_le_bytes().to_vec()));
                }
            }

            Ok(result)
        })
    }

    fn delete(&self, key: &[u8]) -> BoxFuture<'_, Result<(), Error>> {
        let key = key.to_vec();
        Box::pin(async move {
            sqlx::query("DELETE FROM kv WHERE key = ?")
                .bind(&key)
                .execute(&self.pool)
                .await
                .map_err(|e| Error::Internal(e.to_string()))?;
            sqlx::query("DELETE FROM counters WHERE key = ?")
                .bind(&key)
                .execute(&self.pool)
                .await
                .map_err(|e| Error::Internal(e.to_string()))?;
            Ok(())
        })
    }
}