use crate::{
anthropic::{MessageBatch, MessageBatchResult},
openai::response::ResponseObject,
};
use serde_json::Value;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;
/// A response held in [`ResponseStore`], together with the input items
/// recorded alongside it.
#[derive(Debug, Clone)]
pub struct StoredResponse {
    /// The response object; `ResponseStore::insert` uses its `id` as the map key.
    pub response: ResponseObject,
    /// Input items kept as raw JSON values.
    /// NOTE(review): presumably the request inputs that produced `response` — confirm with callers.
    pub input_items: Vec<Value>,
}
/// In-memory store of [`StoredResponse`]s keyed by response id.
///
/// The map lives behind an `Arc<RwLock<..>>`, so cloning a `ResponseStore`
/// is cheap and yields another handle to the same shared map.
/// `Default` produces an empty store.
///
/// NOTE(review): no removal method is defined in the visible `impl`, so
/// entries appear to live for the life of the store — confirm whether
/// eviction is handled elsewhere.
#[derive(Debug, Clone, Default)]
pub struct ResponseStore {
    /// Stored responses keyed by `StoredResponse::response.id`.
    inner: Arc<RwLock<HashMap<String, StoredResponse>>>,
}
impl ResponseStore {
pub async fn insert(&self, stored: StoredResponse) {
self.inner
.write()
.await
.insert(stored.response.id.clone(), stored);
}
#[must_use]
pub async fn get(&self, id: &str) -> Option<StoredResponse> {
self.inner.read().await.get(id).cloned()
}
}
/// A message batch held in [`BatchStore`], together with its results and
/// cancellation flag.
#[derive(Debug, Clone)]
pub struct StoredBatch {
    /// The batch itself; `BatchStore::insert` uses its `id` as the map key.
    pub batch: MessageBatch,
    /// Results recorded for this batch.
    /// NOTE(review): presumably appended as individual requests finish — confirm with the processing code.
    pub results: Vec<MessageBatchResult>,
    /// Whether cancellation has been requested for this batch.
    /// Readable without cloning the whole entry via `BatchStore::cancel_requested`.
    pub cancel_requested: bool,
}
/// In-memory store of [`StoredBatch`]es keyed by batch id.
///
/// The map lives behind an `Arc<RwLock<..>>`, so cloning a `BatchStore` is
/// cheap and yields another handle to the same shared map.
/// `Default` produces an empty store.
#[derive(Debug, Clone, Default)]
pub struct BatchStore {
    /// Stored batches keyed by the `StoredBatch::batch.id` they had when inserted.
    inner: Arc<RwLock<HashMap<String, StoredBatch>>>,
}
impl BatchStore {
    /// Stores `stored` under its batch id.
    ///
    /// Any entry previously stored under the same id is replaced.
    pub async fn insert(&self, stored: StoredBatch) {
        self.inner
            .write()
            .await
            .insert(stored.batch.id.clone(), stored);
    }

    /// Returns a clone of the batch stored under `id`.
    ///
    /// Returns `None` when no such batch exists.
    #[must_use]
    pub async fn get(&self, id: &str) -> Option<StoredBatch> {
        self.inner.read().await.get(id).cloned()
    }

    /// Returns clones of every stored batch, ordered by `created_at` descending.
    ///
    /// Newest first, assuming `created_at` orders chronologically. Batches
    /// with equal `created_at` values are ordered by id, also descending.
    ///
    /// The backing `HashMap` iterates in an arbitrary, per-process order.
    /// Without the tie-breaker, batches sharing a `created_at` value would be
    /// listed non-deterministically, making listing and pagination unstable.
    #[must_use]
    pub async fn list(&self) -> Vec<StoredBatch> {
        // The read guard is a temporary of this statement, so the lock is
        // released before sorting.
        let mut batches = self
            .inner
            .read()
            .await
            .values()
            .cloned()
            .collect::<Vec<_>>();
        batches.sort_by(|left, right| {
            right
                .batch
                .created_at
                .cmp(&left.batch.created_at)
                .then_with(|| right.batch.id.cmp(&left.batch.id))
        });
        batches
    }

    /// Applies `update` to the batch stored under `id` and returns a clone of
    /// the updated batch.
    ///
    /// Returns `None`, without calling `update`, when no such batch exists.
    ///
    /// `update` runs while the store's write lock is held, so keep it short.
    /// The map key is not recomputed: if `update` changes `batch.id`, the
    /// entry remains stored under its original id.
    pub async fn update<F>(&self, id: &str, update: F) -> Option<StoredBatch>
    where
        F: FnOnce(&mut StoredBatch),
    {
        let mut guard = self.inner.write().await;
        let stored = guard.get_mut(id)?;
        update(stored);
        let updated = stored.clone();
        // Release the write lock explicitly before handing the snapshot back.
        drop(guard);
        Some(updated)
    }

    /// Removes the batch stored under `id` and returns it.
    ///
    /// Returns `None` when no such batch exists.
    pub async fn remove(&self, id: &str) -> Option<StoredBatch> {
        self.inner.write().await.remove(id)
    }

    /// Returns the `cancel_requested` flag of the batch stored under `id`.
    ///
    /// Does not clone the whole entry. Returns `None` when no such batch exists.
    #[must_use]
    pub async fn cancel_requested(&self, id: &str) -> Option<bool> {
        self.inner
            .read()
            .await
            .get(id)
            .map(|stored| stored.cancel_requested)
    }
}