use std::io::Write;
use std::path::Path;
use serde::{Deserialize, Serialize};
/// Session file format version: stamped into every header built by
/// `SessionHeader::new`, and the only version `SessionReader` accepts.
const FORMAT_VERSION: u32 = 1;
/// Header record of a session file.
///
/// `SessionWriter::create` writes it as the first line of the file and
/// `SessionReader` parses the first line back into this type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionHeader {
/// Record discriminator, serialized under the JSON key `"type"`.
/// The reader rejects any value other than `"session"`.
#[serde(rename = "type")]
pub type_: String,
/// Format version; the reader rejects anything other than `FORMAT_VERSION`.
pub version: u32,
/// Session identifier (opaque string; not interpreted in this file).
pub id: String,
/// Timestamp as an opaque string; its format is not validated here.
pub timestamp: String,
/// NOTE(review): presumably the working directory of the session — confirm.
pub cwd: String,
/// NOTE(review): presumably identifies the session this one was derived
/// from, `None` when there is none — confirm against callers.
pub parent_session: Option<String>,
}
impl SessionHeader {
    /// Builds a header for a new session.
    ///
    /// The caller supplies `id`, `timestamp`, `cwd` and `parent_session`;
    /// the record type is fixed to `"session"` and the version to
    /// `FORMAT_VERSION`.
    pub fn new(id: String, timestamp: String, cwd: String, parent_session: Option<String>) -> Self {
        // The two fields the caller does not control.
        let type_ = String::from("session");
        let version = FORMAT_VERSION;

        Self {
            parent_session,
            cwd,
            timestamp,
            id,
            version,
            type_,
        }
    }
}
/// Payload of `SessionEntry::Message`: one stored message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageEntry {
/// Identifier of this entry (what `SessionEntry::entry_id` returns).
pub id: String,
/// NOTE(review): presumably the `id` of the preceding entry, `None` for a
/// root entry — confirm; nothing in this file interprets it.
pub parent_id: Option<String>,
/// Timestamp as an opaque string; its format is not validated here.
pub timestamp: String,
/// The message itself, (de)serialized via the `opi_ai` message type.
pub message: opi_ai::message::Message,
}
/// Payload of `SessionEntry::Compaction`.
///
/// NOTE(review): field meanings below are inferred from their names; nothing
/// in this file reads them — confirm against the code that produces them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompactionEntry {
/// Identifier of this entry (what `SessionEntry::entry_id` returns).
pub id: String,
/// Presumably the `id` of the preceding entry, if any.
pub parent_id: Option<String>,
/// Timestamp as an opaque string; its format is not validated here.
pub timestamp: String,
/// Presumably the text summarizing the compacted history.
pub summary: String,
/// Presumably the `id` of the first entry retained after compaction.
pub first_kept_entry_id: String,
/// Presumably the token count before compaction.
pub tokens_before: u64,
/// Presumably the token count after compaction.
pub tokens_after: u64,
}
/// Payload of `SessionEntry::Leaf`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LeafEntry {
/// Identifier of this entry (what `SessionEntry::entry_id` returns).
pub id: String,
/// NOTE(review): presumably the `id` of the preceding entry, if any — confirm.
pub parent_id: Option<String>,
/// Timestamp as an opaque string; its format is not validated here.
pub timestamp: String,
/// NOTE(review): presumably the `id` of the entry this leaf marker points
/// at — confirm; nothing in this file interprets it.
pub entry_id: String,
}
/// One record following the header in a session file.
///
/// Serialized as an internally tagged JSON object: the `"type"` key holds
/// the snake_case variant name (`"message"`, `"compaction"`, `"leaf"`) and
/// the payload's fields sit alongside it.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum SessionEntry {
/// A stored message.
Message(MessageEntry),
/// A compaction record.
Compaction(CompactionEntry),
/// A leaf marker.
Leaf(LeafEntry),
}
impl SessionEntry {
pub fn entry_id(&self) -> &str {
match self {
SessionEntry::Message(m) => &m.id,
SessionEntry::Compaction(c) => &c.id,
SessionEntry::Leaf(l) => &l.id,
}
}
}
/// What `SessionReader::read_with_recovery` had to tolerate while reading.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CrashRecovery {
    /// Every entry line parsed and the file ended with a line terminator.
    Clean,
    /// No corrupt entries, but the file did not end with a line terminator.
    TruncatedLine,
    /// Some entry lines failed to parse; the file ended with a terminator.
    CorruptEntries {
        /// Number of entry lines that failed to parse.
        count: usize,
    },
    /// Some entry lines failed to parse and the file did not end with a
    /// line terminator.
    CorruptEntriesWithTruncation {
        /// Number of entry lines that failed to parse.
        count: usize,
    },
}

impl CrashRecovery {
    /// Number of entry lines that failed to parse; `0` for the variants
    /// that carry no count.
    pub fn corrupt_count(&self) -> usize {
        match *self {
            Self::CorruptEntries { count } | Self::CorruptEntriesWithTruncation { count } => count,
            Self::Clean | Self::TruncatedLine => 0,
        }
    }
}
/// Writer for a session file: one header line followed by one JSON entry
/// per line. Obtain one with `create` (new file) or `open` (existing file).
pub struct SessionWriter {
// Open handle positioned at end of file; every append writes here.
file: std::fs::File,
}
impl SessionWriter {
    /// Creates the file at `path` (truncating any existing file), writes
    /// `header` as its first line and syncs it to disk.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from creating, writing or syncing the file, or
    /// the serialization error converted to `std::io::Error`.
    pub fn create(path: &Path, header: SessionHeader) -> std::io::Result<Self> {
        let mut file = std::fs::File::create(path)?;
        // Build the whole line first so payload and newline go out in one
        // `write_all`; `writeln!` formats piecewise, which on an unbuffered
        // `File` can reach the OS as separate writes.
        let mut line = serde_json::to_string(&header)?;
        line.push('\n');
        file.write_all(line.as_bytes())?;
        file.sync_all()?;
        Ok(Self { file })
    }

    /// Opens an existing session file for appending.
    ///
    /// If the file does not end with `\n`, the trailing partial line (what a
    /// crash in the middle of a write leaves behind) is cut off so the next
    /// append starts on a fresh line. When the file contains no `\n` at all
    /// it is truncated to zero length. The handle is left positioned at the
    /// end of the file.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from opening, reading, seeking or truncating.
    pub fn open(path: &Path) -> std::io::Result<Self> {
        use std::io::{Read, Seek, SeekFrom};

        // Scan backwards in blocks rather than one byte per seek+read pair.
        const SCAN_CHUNK: usize = 4096;

        let mut file = std::fs::OpenOptions::new()
            .read(true)
            .write(true)
            .open(path)?;
        let len = file.seek(SeekFrom::End(0))?;

        // Length to keep: one past the last `\n`, or 0 when there is none.
        let mut keep_len = 0u64;
        let mut scan_end = len;
        let mut buf = [0u8; SCAN_CHUNK];
        while scan_end > 0 {
            // Bounded by SCAN_CHUNK, so narrowing to usize cannot truncate.
            let chunk_len = scan_end.min(SCAN_CHUNK as u64) as usize;
            let chunk_start = scan_end - chunk_len as u64;
            let chunk = &mut buf[..chunk_len];
            file.seek(SeekFrom::Start(chunk_start))?;
            file.read_exact(chunk)?;
            if let Some(idx) = chunk.iter().rposition(|&b| b == b'\n') {
                // usize -> u64 is lossless; idx < SCAN_CHUNK.
                keep_len = chunk_start + idx as u64 + 1;
                break;
            }
            scan_end = chunk_start;
        }

        // Only touch the file when there really is a partial tail.
        if keep_len != len {
            file.set_len(keep_len)?;
        }
        file.seek(SeekFrom::End(0))?;
        Ok(Self { file })
    }

    /// Appends `entry` as one JSON line and syncs the file to disk.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from writing or syncing, or the serialization
    /// error converted to `std::io::Error`.
    pub fn append(&mut self, entry: &SessionEntry) -> std::io::Result<()> {
        // One buffer, one `write_all`: a crash cannot land between the JSON
        // payload and its newline, which would make the reader discard an
        // otherwise complete entry as truncated.
        let mut line = serde_json::to_string(entry)?;
        line.push('\n');
        self.file.write_all(line.as_bytes())?;
        self.file.sync_all()
    }
}
/// Stateless namespace for reading session files; see `read_all` and
/// `read_with_recovery`.
pub struct SessionReader;
impl SessionReader {
    /// Reads the header and all parseable entries from `path`, discarding
    /// the recovery report. See [`Self::read_with_recovery`] for details.
    ///
    /// # Errors
    ///
    /// Same as [`Self::read_with_recovery`].
    pub fn read_all(path: &Path) -> std::io::Result<(SessionHeader, Vec<SessionEntry>)> {
        let (header, entries, _recovery) = Self::read_with_recovery(path)?;
        Ok((header, entries))
    }

    /// Reads a session file, tolerating damage left behind by a crash.
    ///
    /// The first line must be a valid header. Every following line is
    /// handled as follows:
    ///
    /// * blank lines are skipped;
    /// * if the file does not end with `\n` or `\r`, the final line is
    ///   treated as a torn write and skipped without being parsed;
    /// * lines that fail to parse (including lines that are not valid
    ///   UTF-8) are skipped and counted as corrupt.
    ///
    /// The returned [`CrashRecovery`] reports which of these happened.
    ///
    /// # Errors
    ///
    /// * `UnexpectedEof` if the file is empty;
    /// * `InvalidData` if the header line cannot be parsed, its `type` is
    ///   not `"session"`, or its `version` is not `FORMAT_VERSION`;
    /// * any I/O error from reading the file.
    pub fn read_with_recovery(
        path: &Path,
    ) -> std::io::Result<(SessionHeader, Vec<SessionEntry>, CrashRecovery)> {
        /// Strips a trailing `\n` or `\r\n`, like `str::lines` does.
        fn strip_line_ending(line: &[u8]) -> &[u8] {
            match line.strip_suffix(b"\n") {
                Some(rest) => rest.strip_suffix(b"\r").unwrap_or(rest),
                None => line,
            }
        }

        let empty_file = || {
            std::io::Error::new(std::io::ErrorKind::UnexpectedEof, "empty session file")
        };
        let invalid_data =
            |msg: String| std::io::Error::new(std::io::ErrorKind::InvalidData, msg);

        // Read raw bytes rather than a `String`: a torn write can cut a
        // multi-byte UTF-8 character in half, and `read_to_string` would
        // then reject the whole file instead of just the damaged line.
        let bytes = std::fs::read(path)?;
        if bytes.is_empty() {
            return Err(empty_file());
        }

        let last_line_incomplete = !bytes.ends_with(b"\n") && !bytes.ends_with(b"\r");

        // `\n` (0x0A) never occurs inside a multi-byte UTF-8 sequence, so
        // splitting on the byte yields the same lines as `str::lines`.
        let all_lines: Vec<&[u8]> = bytes
            .split_inclusive(|&b| b == b'\n')
            .map(strip_line_ending)
            .collect();
        let (header_line, data_lines) = all_lines.split_first().ok_or_else(empty_file)?;

        let header_text = std::str::from_utf8(header_line)
            .map_err(|e| invalid_data(format!("invalid session header: {e}")))?;
        let header: SessionHeader = serde_json::from_str(header_text)
            .map_err(|e| invalid_data(format!("invalid session header: {e}")))?;

        if header.type_ != "session" {
            return Err(invalid_data(format!(
                "expected header type 'session', got '{}'",
                header.type_
            )));
        }
        if header.version != FORMAT_VERSION {
            return Err(invalid_data(format!(
                "unsupported session version {}, expected {}",
                header.version, FORMAT_VERSION
            )));
        }

        let last_index = data_lines.len().checked_sub(1);
        let mut entries = Vec::new();
        let mut corrupt_count = 0;
        for (i, raw) in data_lines.iter().enumerate() {
            // `None` means the line is not valid UTF-8.
            let text = std::str::from_utf8(raw).ok();

            if let Some(t) = text {
                if t.trim().is_empty() {
                    continue;
                }
            }
            // A final line without a terminator is a torn write, not
            // corruption: drop it without counting it.
            if last_line_incomplete && Some(i) == last_index {
                continue;
            }
            match text {
                Some(t) => match serde_json::from_str::<SessionEntry>(t) {
                    Ok(entry) => entries.push(entry),
                    Err(_) => corrupt_count += 1,
                },
                None => corrupt_count += 1,
            }
        }

        let recovery = match (corrupt_count > 0, last_line_incomplete) {
            (true, true) => CrashRecovery::CorruptEntriesWithTruncation {
                count: corrupt_count,
            },
            (true, false) => CrashRecovery::CorruptEntries {
                count: corrupt_count,
            },
            (false, true) => CrashRecovery::TruncatedLine,
            (false, false) => CrashRecovery::Clean,
        };
        Ok((header, entries, recovery))
    }
}