use async_trait::async_trait;
use serde::Deserialize;
use std::path::PathBuf;
use tirea_contract::storage::{
Committed, ThreadHead, ThreadListPage, ThreadListQuery, ThreadReader, ThreadStoreError,
ThreadWriter, VersionPrecondition,
};
use tirea_contract::{Thread, ThreadChangeSet, Version};
use tokio::io::AsyncWriteExt;
pub struct FileStore {
base_path: PathBuf,
}
impl FileStore {
pub fn new(base_path: impl Into<PathBuf>) -> Self {
Self {
base_path: base_path.into(),
}
}
pub(super) fn thread_path(&self, thread_id: &str) -> Result<PathBuf, ThreadStoreError> {
Self::validate_thread_id(thread_id)?;
Ok(self.base_path.join(format!("{}.json", thread_id)))
}
fn validate_thread_id(thread_id: &str) -> Result<(), ThreadStoreError> {
if thread_id.is_empty() {
return Err(ThreadStoreError::InvalidId(
"thread id cannot be empty".to_string(),
));
}
if thread_id.contains('/')
|| thread_id.contains('\\')
|| thread_id.contains("..")
|| thread_id.contains('\0')
{
return Err(ThreadStoreError::InvalidId(format!(
"thread id contains invalid characters: {thread_id:?}"
)));
}
if thread_id.chars().any(|c| c.is_control()) {
return Err(ThreadStoreError::InvalidId(format!(
"thread id contains control characters: {thread_id:?}"
)));
}
Ok(())
}
}
#[async_trait]
impl ThreadWriter for FileStore {
async fn create(&self, thread: &Thread) -> Result<Committed, ThreadStoreError> {
let path = self.thread_path(&thread.id)?;
if path.exists() {
return Err(ThreadStoreError::AlreadyExists);
}
let head = ThreadHead {
thread: thread.clone(),
version: 0,
};
self.save_head(&head).await?;
Ok(Committed { version: 0 })
}
async fn append(
&self,
thread_id: &str,
delta: &ThreadChangeSet,
precondition: VersionPrecondition,
) -> Result<Committed, ThreadStoreError> {
let head = self
.load_head(thread_id)
.await?
.ok_or_else(|| ThreadStoreError::NotFound(thread_id.to_string()))?;
if let VersionPrecondition::Exact(expected) = precondition {
if head.version != expected {
return Err(ThreadStoreError::VersionConflict {
expected,
actual: head.version,
});
}
}
let mut thread = head.thread;
delta.apply_to(&mut thread);
let new_version = head.version + 1;
let new_head = ThreadHead {
thread,
version: new_version,
};
self.save_head(&new_head).await?;
Ok(Committed {
version: new_version,
})
}
async fn delete(&self, thread_id: &str) -> Result<(), ThreadStoreError> {
let path = self.thread_path(thread_id)?;
if path.exists() {
tokio::fs::remove_file(&path).await?;
}
Ok(())
}
async fn save(&self, thread: &Thread) -> Result<(), ThreadStoreError> {
let next_version = self
.load_head(&thread.id)
.await?
.map_or(0, |head| head.version.saturating_add(1));
let head = ThreadHead {
thread: thread.clone(),
version: next_version,
};
self.save_head(&head).await
}
}
#[async_trait]
impl ThreadReader for FileStore {
async fn load(&self, thread_id: &str) -> Result<Option<ThreadHead>, ThreadStoreError> {
self.load_head(thread_id).await
}
async fn list_threads(
&self,
query: &ThreadListQuery,
) -> Result<ThreadListPage, ThreadStoreError> {
let mut all = if !self.base_path.exists() {
Vec::new()
} else {
let mut entries = tokio::fs::read_dir(&self.base_path).await?;
let mut ids = Vec::new();
while let Some(entry) = entries.next_entry().await? {
let path = entry.path();
if path.extension().is_some_and(|ext| ext == "json") {
if let Some(stem) = path.file_stem() {
if let Some(id) = stem.to_str() {
ids.push(id.to_string());
}
}
}
}
ids
};
if let Some(ref resource_id) = query.resource_id {
let mut filtered = Vec::new();
for id in &all {
if let Some(head) = self.load(id).await? {
if head.thread.resource_id.as_deref() == Some(resource_id.as_str()) {
filtered.push(id.clone());
}
}
}
all = filtered;
}
if let Some(ref parent_thread_id) = query.parent_thread_id {
let mut filtered = Vec::new();
for id in &all {
if let Some(head) = self.load(id).await? {
if head.thread.parent_thread_id.as_deref() == Some(parent_thread_id.as_str()) {
filtered.push(id.clone());
}
}
}
all = filtered;
}
all.sort();
let total = all.len();
let limit = query.limit.clamp(1, 200);
let offset = query.offset.min(total);
let end = (offset + limit + 1).min(total);
let slice = &all[offset..end];
let has_more = slice.len() > limit;
let items: Vec<String> = slice.iter().take(limit).cloned().collect();
Ok(ThreadListPage {
items,
total,
has_more,
})
}
}
impl FileStore {
async fn load_head(&self, thread_id: &str) -> Result<Option<ThreadHead>, ThreadStoreError> {
let path = self.thread_path(thread_id)?;
if !path.exists() {
return Ok(None);
}
let content = tokio::fs::read_to_string(&path).await?;
if let Ok(head) = serde_json::from_str::<VersionedThread>(&content) {
let thread: Thread = serde_json::from_str(&content)
.map_err(|e| ThreadStoreError::Serialization(e.to_string()))?;
Ok(Some(ThreadHead {
thread,
version: head._version.unwrap_or(0),
}))
} else {
let thread: Thread = serde_json::from_str(&content)
.map_err(|e| ThreadStoreError::Serialization(e.to_string()))?;
Ok(Some(ThreadHead { thread, version: 0 }))
}
}
async fn save_head(&self, head: &ThreadHead) -> Result<(), ThreadStoreError> {
if !self.base_path.exists() {
tokio::fs::create_dir_all(&self.base_path).await?;
}
let path = self.thread_path(&head.thread.id)?;
let mut v = serde_json::to_value(&head.thread)
.map_err(|e| ThreadStoreError::Serialization(e.to_string()))?;
if let Some(obj) = v.as_object_mut() {
obj.insert("_version".to_string(), serde_json::json!(head.version));
}
let content = serde_json::to_string_pretty(&v)
.map_err(|e| ThreadStoreError::Serialization(e.to_string()))?;
let tmp_path = self.base_path.join(format!(
".{}.{}.tmp",
head.thread.id,
uuid::Uuid::new_v4().simple()
));
let write_result = async {
let mut file = tokio::fs::File::create(&tmp_path).await?;
file.write_all(content.as_bytes()).await?;
file.flush().await?;
file.sync_all().await?;
drop(file);
match tokio::fs::rename(&tmp_path, &path).await {
Ok(()) => {}
Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
tokio::fs::remove_file(&path).await?;
tokio::fs::rename(&tmp_path, &path).await?;
}
Err(e) => return Err(e),
}
Ok::<(), std::io::Error>(())
}
.await;
if let Err(e) = write_result {
let _ = tokio::fs::remove_file(&tmp_path).await;
return Err(ThreadStoreError::Io(e));
}
Ok(())
}
}
#[derive(Deserialize)]
struct VersionedThread {
#[serde(default)]
_version: Option<Version>,
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use std::sync::Arc;
use tempfile::TempDir;
use tirea_contract::{
storage::ThreadReader, CheckpointReason, Message, MessageQuery, ThreadWriter,
};
use tirea_state::{path, Op, Patch, TrackedPatch};
fn make_thread_with_messages(thread_id: &str, n: usize) -> Thread {
let mut thread = Thread::new(thread_id);
for i in 0..n {
thread = thread.with_message(Message::user(format!("msg-{i}")));
}
thread
}
#[tokio::test]
async fn file_storage_save_load_roundtrip() {
let temp_dir = TempDir::new().unwrap();
let storage = FileStore::new(temp_dir.path());
let thread = Thread::new("test-1").with_message(Message::user("hello"));
storage.save(&thread).await.unwrap();
let loaded = storage.load_thread("test-1").await.unwrap().unwrap();
assert_eq!(loaded.id, "test-1");
assert_eq!(loaded.message_count(), 1);
}
#[tokio::test]
async fn file_storage_list_and_delete() {
let temp_dir = TempDir::new().unwrap();
let storage = FileStore::new(temp_dir.path());
storage.create(&Thread::new("thread-a")).await.unwrap();
storage.create(&Thread::new("thread-b")).await.unwrap();
storage.create(&Thread::new("thread-c")).await.unwrap();
let mut ids = storage.list().await.unwrap();
ids.sort();
assert_eq!(ids, vec!["thread-a", "thread-b", "thread-c"]);
storage.delete("thread-b").await.unwrap();
let mut ids = storage.list().await.unwrap();
ids.sort();
assert_eq!(ids, vec!["thread-a", "thread-c"]);
}
#[tokio::test]
async fn file_storage_message_queries() {
let temp_dir = TempDir::new().unwrap();
let storage = FileStore::new(temp_dir.path());
let thread = make_thread_with_messages("t1", 10);
storage.save(&thread).await.unwrap();
let page = storage
.load_messages(
"t1",
&MessageQuery {
after: Some(4),
limit: 3,
..Default::default()
},
)
.await
.unwrap();
assert_eq!(page.messages.len(), 3);
assert_eq!(page.messages[0].cursor, 5);
assert_eq!(page.messages[0].message.content, "msg-5");
assert_eq!(storage.message_count("t1").await.unwrap(), 10);
}
#[tokio::test]
async fn file_storage_append_and_versioning() {
let temp_dir = TempDir::new().unwrap();
let store = FileStore::new(temp_dir.path());
store.create(&Thread::new("t1")).await.unwrap();
let d1 = ThreadChangeSet {
run_id: "run-1".to_string(),
parent_run_id: None,
reason: CheckpointReason::UserMessage,
messages: vec![Arc::new(Message::user("hello"))],
patches: vec![],
snapshot: None,
};
let c1 = store
.append("t1", &d1, VersionPrecondition::Exact(0))
.await
.unwrap();
assert_eq!(c1.version, 1);
let d2 = ThreadChangeSet {
run_id: "run-1".to_string(),
parent_run_id: None,
reason: CheckpointReason::AssistantTurnCommitted,
messages: vec![Arc::new(Message::assistant("hi"))],
patches: vec![TrackedPatch::new(
Patch::new().with_op(Op::set(path!("greeted"), json!(true))),
)],
snapshot: None,
};
let c2 = store
.append("t1", &d2, VersionPrecondition::Exact(1))
.await
.unwrap();
assert_eq!(c2.version, 2);
let d3 = ThreadChangeSet {
run_id: "run-1".to_string(),
parent_run_id: None,
reason: CheckpointReason::RunFinished,
messages: vec![],
patches: vec![],
snapshot: Some(json!({"greeted": true})),
};
let c3 = store
.append("t1", &d3, VersionPrecondition::Exact(2))
.await
.unwrap();
assert_eq!(c3.version, 3);
let store2 = FileStore::new(temp_dir.path());
let head = store2.load("t1").await.unwrap().unwrap();
assert_eq!(head.version, 3);
assert_eq!(head.thread.message_count(), 2);
assert!(head.thread.patches.is_empty());
assert_eq!(head.thread.state, json!({"greeted": true}));
}
#[test]
fn file_storage_rejects_path_traversal() {
let storage = FileStore::new("/base/path");
assert!(storage.thread_path("../../etc/passwd").is_err());
assert!(storage.thread_path("foo/bar").is_err());
assert!(storage.thread_path("foo\\bar").is_err());
assert!(storage.thread_path("").is_err());
assert!(storage.thread_path("foo\0bar").is_err());
}
}