use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::SystemTime;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::error::{ServerError, ServerResult};
/// Lifecycle state of a stored [`ResponseRecord`].
///
/// Serialized with serde's `snake_case` renaming, so e.g. `InProgress` is written as
/// `"in_progress"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ResponseStatus {
    /// Record exists but no output has been stored yet; this is the initial status
    /// assigned by `ResponseRecord::new_in_progress`.
    InProgress,
    /// Output was stored successfully — NOTE(review): set by callers of
    /// `ResponseStore::update_output`; exact trigger not visible here.
    Completed,
    /// Processing failed — NOTE(review): set by callers of `ResponseStore::update_output`;
    /// exact trigger not visible here.
    Failed,
}
/// A single stored response, normally built with [`ResponseRecord::new_in_progress`] and
/// kept in a [`ResponseStore`].
///
/// The derived serde impls emit fields in declaration order, so the order below is part of
/// the serialized shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseRecord {
    /// Unique id; `new_in_progress` generates `resp_` followed by a hyphen-less UUID v4.
    /// Also used as the key inside [`ResponseStore`].
    pub id: String,
    /// Object type tag; `new_in_progress` always sets it to `"response"`.
    pub object: String,
    /// Creation time in whole seconds since the Unix epoch, as set by `new_in_progress`
    /// (which falls back to 0 if the system clock reads earlier than the epoch).
    pub created_at: u64,
    /// Model name supplied by the caller; stored verbatim.
    pub model: String,
    /// Current lifecycle state.
    pub status: ResponseStatus,
    /// Input items as arbitrary JSON values; no schema is enforced here
    /// (tests use `{"role": ..., "content": ...}` objects — presumably the usual shape).
    pub input: Vec<serde_json::Value>,
    /// Output text; `None` until `ResponseStore::update_output` stores one.
    pub output: Option<String>,
    /// Optional id of an earlier response — presumably for chaining turns; not validated here.
    pub previous_response_id: Option<String>,
    /// Optional caller-supplied instructions; stored verbatim.
    pub instructions: Option<String>,
    /// Tool definitions as arbitrary JSON values; not interpreted by this module.
    pub tools: Vec<serde_json::Value>,
}
impl ResponseRecord {
    /// Builds a fresh record in the [`ResponseStatus::InProgress`] state.
    ///
    /// * `id` is `resp_` followed by a new UUID v4 rendered without hyphens.
    /// * `object` is fixed to `"response"`.
    /// * `output` starts as `None`.
    /// * `created_at` is the current Unix time in whole seconds, or `0` if the system clock
    ///   reports a time before the epoch.
    ///
    /// All other fields are taken from the arguments unchanged.
    pub fn new_in_progress(
        model: String,
        input: Vec<serde_json::Value>,
        previous_response_id: Option<String>,
        instructions: Option<String>,
        tools: Vec<serde_json::Value>,
    ) -> Self {
        // A clock set before 1970 is not worth failing record creation over.
        let now_secs = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .map_or(0, |elapsed| elapsed.as_secs());

        Self {
            id: format!("resp_{}", Uuid::new_v4().simple()),
            object: String::from("response"),
            created_at: now_secs,
            model,
            status: ResponseStatus::InProgress,
            input,
            output: None,
            previous_response_id,
            instructions,
            tools,
        }
    }
}
/// In-memory store of [`ResponseRecord`]s keyed by their `id`.
///
/// The map sits behind `Arc<RwLock<..>>`, so cloning a `ResponseStore` is cheap and every
/// clone reads and writes the same underlying records.
#[derive(Debug, Clone)]
pub struct ResponseStore {
    /// Shared map from record id to record.
    records: Arc<RwLock<HashMap<String, ResponseRecord>>>,
}
impl Default for ResponseStore {
fn default() -> Self {
Self::new()
}
}
impl ResponseStore {
pub fn new() -> Self {
Self {
records: Arc::new(RwLock::new(HashMap::new())),
}
}
pub fn create(&self, rec: ResponseRecord) -> ServerResult<String> {
let id = rec.id.clone();
let mut map = self.records.write().map_err(|e| {
ServerError::FileStoreError(format!("response store lock poisoned: {e}"))
})?;
map.insert(id.clone(), rec);
Ok(id)
}
pub fn get(&self, id: &str) -> ServerResult<ResponseRecord> {
let map = self.records.read().map_err(|e| {
ServerError::FileStoreError(format!("response store lock poisoned: {e}"))
})?;
map.get(id)
.cloned()
.ok_or_else(|| ServerError::ResponseNotFound(id.to_string()))
}
pub fn update_output(
&self,
id: &str,
output: String,
status: ResponseStatus,
) -> ServerResult<()> {
let mut map = self.records.write().map_err(|e| {
ServerError::FileStoreError(format!("response store lock poisoned: {e}"))
})?;
let rec = map
.get_mut(id)
.ok_or_else(|| ServerError::ResponseNotFound(id.to_string()))?;
rec.output = Some(output);
rec.status = status;
Ok(())
}
pub fn list(&self) -> Vec<ResponseRecord> {
let map = self.records.read().unwrap_or_else(|e| e.into_inner());
let mut records: Vec<ResponseRecord> = map.values().cloned().collect();
records.sort_by_key(|b| std::cmp::Reverse(b.created_at));
records
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an in-progress record for `model` with a single user message and no tools.
    fn make_record(model: &str) -> ResponseRecord {
        ResponseRecord::new_in_progress(
            model.to_string(),
            vec![serde_json::json!({"role": "user", "content": "hello"})],
            None,
            None,
            vec![],
        )
    }

    #[test]
    fn responses_create_returns_id() {
        let store = ResponseStore::new();
        let rec = make_record("gpt-test");
        let id = store.create(rec).expect("create should succeed");
        assert!(
            id.starts_with("resp_"),
            "id should start with 'resp_': {id}"
        );
    }

    #[test]
    fn responses_get_retrieves_record() {
        let store = ResponseStore::new();
        let rec = make_record("test-model");
        let original_id = rec.id.clone();
        let stored_id = store.create(rec).expect("create should succeed");
        assert_eq!(stored_id, original_id);
        let retrieved = store.get(&stored_id).expect("get should succeed");
        assert_eq!(retrieved.id, stored_id);
        assert_eq!(retrieved.model, "test-model");
        assert_eq!(retrieved.status, ResponseStatus::InProgress);
        assert!(retrieved.output.is_none());
    }

    #[test]
    fn responses_unknown_id_returns_not_found() {
        let store = ResponseStore::new();
        let err = store.get("resp_does_not_exist").unwrap_err();
        assert!(
            matches!(err, ServerError::ResponseNotFound(_)),
            "expected ResponseNotFound, got: {err:?}"
        );
    }

    #[test]
    fn responses_list_returns_descending() {
        let store = ResponseStore::new();
        // Distinct, increasing timestamps make the expected order unambiguous.
        for created_at in 1u64..=3 {
            let mut rec = make_record("model");
            rec.created_at = created_at;
            store.create(rec).expect("create");
        }
        let stamps: Vec<u64> = store.list().iter().map(|r| r.created_at).collect();
        assert_eq!(stamps, vec![3, 2, 1], "list should be sorted newest first");
    }

    #[test]
    fn responses_list_empty_store_is_empty() {
        assert!(ResponseStore::new().list().is_empty());
    }

    #[test]
    fn responses_update_output_changes_status() {
        let store = ResponseStore::new();
        let rec = make_record("m");
        let id = store.create(rec).expect("create");
        store
            .update_output(&id, "hello world".to_string(), ResponseStatus::Completed)
            .expect("update_output should succeed");
        let updated = store.get(&id).expect("get after update");
        assert_eq!(updated.status, ResponseStatus::Completed);
        assert_eq!(updated.output.as_deref(), Some("hello world"));
    }

    #[test]
    fn responses_update_output_unknown_id_returns_not_found() {
        let store = ResponseStore::new();
        let err = store
            .update_output(
                "resp_does_not_exist",
                "ignored".to_string(),
                ResponseStatus::Failed,
            )
            .unwrap_err();
        assert!(
            matches!(err, ServerError::ResponseNotFound(_)),
            "expected ResponseNotFound, got: {err:?}"
        );
    }

    #[test]
    fn responses_clones_share_records() {
        // The map lives behind an Arc, so a clone must observe writes made through the original.
        let store = ResponseStore::new();
        let view = store.clone();
        let id = store.create(make_record("shared")).expect("create");
        let seen = view.get(&id).expect("clone should see the record");
        assert_eq!(seen.model, "shared");
    }
}