// kaizen/store/cold_parquet.rs
// SPDX-License-Identifier: AGPL-3.0-or-later
//! COLD Parquet partitions for event-shaped rows.

use crate::core::event::{Event, EventKind, EventSource};
use anyhow::Result;
use arrow::array::{
    Array, ArrayRef, BooleanArray, Int64Array, StringArray, UInt16Array, UInt32Array, UInt64Array,
};
use arrow::record_batch::RecordBatch;
use parquet::arrow::ArrowWriter;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
use parquet::basic::{Compression, ZstdLevel};
use parquet::file::metadata::KeyValue;
use parquet::file::properties::WriterProperties;
use std::collections::BTreeMap;
use std::fs::File;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use time::{OffsetDateTime, format_description::well_known::Iso8601};

/// Version tag written into each Parquet file's key/value metadata
/// (key `kaizen_schema_v`) so readers can detect layout changes.
pub const SCHEMA_VERSION: &str = "1";

/// Streaming writer that buffers events per UTC day and spills each full
/// day bucket to a numbered Parquet chunk file under `<root>/cold/events/`.
pub struct DailyEventWriter {
    root: PathBuf,                        // store root; files land in root/cold/events
    max_rows: usize,                      // per-day flush threshold (clamped to >= 1 in `new`)
    groups: BTreeMap<String, Vec<Event>>, // buffered rows keyed by ISO-8601 day string
    next_chunk: BTreeMap<String, u64>,    // next chunk index per day (for {day}-NNNNNN names)
    paths: Vec<PathBuf>,                  // every file written so far, returned by `finish`
}

31impl DailyEventWriter {
32    pub fn new(root: &Path, max_rows: usize) -> Self {
33        Self {
34            root: root.to_path_buf(),
35            max_rows: max_rows.max(1),
36            groups: BTreeMap::new(),
37            next_chunk: BTreeMap::new(),
38            paths: Vec::new(),
39        }
40    }
41
42    pub fn push(&mut self, event: Event) -> Result<()> {
43        let day = partition_day(event.ts_ms)?;
44        let full = push_group(&mut self.groups, day.clone(), event, self.max_rows);
45        if full {
46            self.flush_day(&day)?;
47        }
48        Ok(())
49    }
50
51    pub fn finish(mut self) -> Result<Vec<PathBuf>> {
52        while let Some(day) = self.groups.keys().next().cloned() {
53            self.flush_day(&day)?;
54        }
55        Ok(self.paths)
56    }
57
58    fn flush_day(&mut self, day: &str) -> Result<()> {
59        let Some(rows) = self.groups.remove(day) else {
60            return Ok(());
61        };
62        let path = self.next_chunk_path(day)?;
63        write_batch(&path, &rows)?;
64        self.paths.push(path);
65        Ok(())
66    }
67
68    fn next_chunk_path(&mut self, day: &str) -> Result<PathBuf> {
69        let dir = self.root.join("cold/events");
70        std::fs::create_dir_all(&dir)?;
71        let n = self.next_chunk.entry(day.to_string()).or_default();
72        let path = dir.join(format!("{day}-{n:06}.parquet"));
73        *n += 1;
74        Ok(path)
75    }
76}
77
78fn push_group(
79    groups: &mut BTreeMap<String, Vec<Event>>,
80    day: String,
81    event: Event,
82    max_rows: usize,
83) -> bool {
84    let rows = groups.entry(day).or_default();
85    rows.push(event);
86    rows.len() >= max_rows
87}
88
89pub fn write_daily_events(root: &Path, events: &[Event]) -> Result<Vec<PathBuf>> {
90    let mut groups: BTreeMap<String, Vec<Event>> = BTreeMap::new();
91    for event in events {
92        groups
93            .entry(partition_day(event.ts_ms)?)
94            .or_default()
95            .push(event.clone());
96    }
97    write_daily_event_groups(root, groups)
98}
99
100pub fn write_daily_event_groups(
101    root: &Path,
102    groups: BTreeMap<String, Vec<Event>>,
103) -> Result<Vec<PathBuf>> {
104    let mut paths = Vec::new();
105    for (day, rows) in groups {
106        let dir = root.join("cold/events");
107        std::fs::create_dir_all(&dir)?;
108        let path = dir.join(format!("{day}.parquet"));
109        write_batch(&path, &rows)?;
110        paths.push(path);
111    }
112    Ok(paths)
113}
114
115pub fn read_events_dir(root: &Path) -> Result<Vec<Event>> {
116    let dir = root.join("cold/events");
117    if !dir.exists() {
118        return Ok(Vec::new());
119    }
120    let mut out = Vec::new();
121    for entry in std::fs::read_dir(dir)? {
122        let path = entry?.path();
123        if path.extension().and_then(|s| s.to_str()) == Some("parquet") {
124            out.extend(read_events_file(&path)?);
125        }
126    }
127    out.sort_by(|a, b| (a.ts_ms, &a.session_id, a.seq).cmp(&(b.ts_ms, &b.session_id, b.seq)));
128    Ok(out)
129}
130
131pub fn remove_partitions_older_than(root: &Path, cutoff_ms: u64) -> Result<u64> {
132    let dir = root.join("cold/events");
133    if !dir.exists() {
134        return Ok(0);
135    }
136    let cutoff = partition_day(cutoff_ms)?;
137    let mut removed = 0;
138    for entry in std::fs::read_dir(dir)? {
139        let path = entry?.path();
140        let Some(stem) = path.file_stem().and_then(|s| s.to_str()) else {
141            continue;
142        };
143        if stem < cutoff.as_str() {
144            std::fs::remove_file(&path)?;
145            removed += 1;
146        }
147    }
148    Ok(removed)
149}
150
151fn write_batch(path: &Path, events: &[Event]) -> Result<()> {
152    let batch = batch_from_events(events)?;
153    let file = File::create(path)?;
154    let props = WriterProperties::builder()
155        .set_compression(Compression::ZSTD(ZstdLevel::default()))
156        .set_key_value_metadata(Some(vec![KeyValue {
157            key: "kaizen_schema_v".into(),
158            value: Some(SCHEMA_VERSION.into()),
159        }]))
160        .build();
161    let mut writer = ArrowWriter::try_new(file, batch.schema(), Some(props))?;
162    writer.write(&batch)?;
163    writer.close()?;
164    Ok(())
165}
166
167fn read_events_file(path: &Path) -> Result<Vec<Event>> {
168    let file = File::open(path)?;
169    let reader = ParquetRecordBatchReaderBuilder::try_new(file)?.build()?;
170    let mut out = Vec::new();
171    for batch in reader {
172        out.extend(events_from_batch(&batch?)?);
173    }
174    Ok(out)
175}
176
/// Build a single Arrow `RecordBatch` with one row per event.
///
/// NOTE: the column ORDER here is a positional contract with
/// `events_from_batch`, which reads columns back by index (0..=21).
/// Add new columns at the end and update both functions together.
fn batch_from_events(events: &[Event]) -> Result<RecordBatch> {
    // Payload is stored as JSON text; an unserializable payload degrades to
    // the literal string "null" instead of failing the whole batch.
    let payload = events
        .iter()
        .map(|e| serde_json::to_string(&e.payload).unwrap_or_else(|_| "null".into()))
        .collect::<Vec<_>>();
    let cols: Vec<(&str, ArrayRef)> = vec![
        // col 0
        (
            "session_id",
            Arc::new(StringArray::from(strs(events, |e| &e.session_id))),
        ),
        // col 1
        ("seq", Arc::new(UInt64Array::from(vals(events, |e| e.seq)))),
        // col 2
        (
            "ts_ms",
            Arc::new(UInt64Array::from(vals(events, |e| e.ts_ms))),
        ),
        // col 3
        (
            "ts_exact",
            Arc::new(BooleanArray::from(vals(events, |e| e.ts_exact))),
        ),
        // col 4: enum encoded via `kind` / decoded via `kind_from`
        (
            "kind",
            Arc::new(StringArray::from(strs(events, |e| kind(&e.kind)))),
        ),
        // col 5: enum encoded via `source` / decoded via `source_from`
        (
            "source",
            Arc::new(StringArray::from(strs(events, |e| source(&e.source)))),
        ),
        // col 6
        (
            "tool",
            Arc::new(StringArray::from(opts(events, |e| e.tool.clone()))),
        ),
        // col 7
        (
            "tool_call_id",
            Arc::new(StringArray::from(opts(events, |e| e.tool_call_id.clone()))),
        ),
        // col 8
        (
            "tokens_in",
            Arc::new(UInt32Array::from(opt_u32(events, |e| e.tokens_in))),
        ),
        // col 9
        (
            "tokens_out",
            Arc::new(UInt32Array::from(opt_u32(events, |e| e.tokens_out))),
        ),
        // col 10
        (
            "reasoning_tokens",
            Arc::new(UInt32Array::from(opt_u32(events, |e| e.reasoning_tokens))),
        ),
        // col 11: cost in micro-USD (e6 fixed point, per field name)
        (
            "cost_usd_e6",
            Arc::new(Int64Array::from(opt_i64(events, |e| e.cost_usd_e6))),
        ),
        // col 12: JSON text built above
        ("payload", Arc::new(StringArray::from(payload))),
        // col 13
        (
            "stop_reason",
            Arc::new(StringArray::from(opts(events, |e| e.stop_reason.clone()))),
        ),
        // col 14
        (
            "latency_ms",
            Arc::new(UInt32Array::from(opt_u32(events, |e| e.latency_ms))),
        ),
        // col 15
        (
            "ttft_ms",
            Arc::new(UInt32Array::from(opt_u32(events, |e| e.ttft_ms))),
        ),
        // col 16
        (
            "retry_count",
            Arc::new(UInt16Array::from(opt_u16(events, |e| e.retry_count))),
        ),
        // col 17
        (
            "context_used_tokens",
            Arc::new(UInt32Array::from(opt_u32(events, |e| {
                e.context_used_tokens
            }))),
        ),
        // col 18
        (
            "context_max_tokens",
            Arc::new(UInt32Array::from(opt_u32(events, |e| e.context_max_tokens))),
        ),
        // col 19
        (
            "cache_creation_tokens",
            Arc::new(UInt32Array::from(opt_u32(events, |e| {
                e.cache_creation_tokens
            }))),
        ),
        // col 20
        (
            "cache_read_tokens",
            Arc::new(UInt32Array::from(opt_u32(events, |e| e.cache_read_tokens))),
        ),
        // col 21
        (
            "system_prompt_tokens",
            Arc::new(UInt32Array::from(opt_u32(events, |e| {
                e.system_prompt_tokens
            }))),
        ),
    ];
    Ok(RecordBatch::try_from_iter(cols)?)
}

/// Reconstruct events from a batch written by `batch_from_events`.
///
/// Columns are addressed by POSITION (0..=21), mirroring the column order in
/// `batch_from_events`; the two functions must change in lockstep.
fn events_from_batch(batch: &RecordBatch) -> Result<Vec<Event>> {
    // Shorthand accessors for the typed column downcasts.
    let s = |i| str_col(batch, i);
    let u64c = |i| u64_col(batch, i);
    let out = (0..batch.num_rows())
        .map(|i| Event {
            session_id: s(0).value(i).into(),
            seq: u64c(1).value(i),
            ts_ms: u64c(2).value(i),
            ts_exact: bool_col(batch, 3).value(i),
            kind: kind_from(s(4).value(i)),
            source: source_from(s(5).value(i)),
            tool: opt_str(batch, 6, i),
            tool_call_id: opt_str(batch, 7, i),
            tokens_in: opt_u32_at(batch, 8, i),
            tokens_out: opt_u32_at(batch, 9, i),
            reasoning_tokens: opt_u32_at(batch, 10, i),
            cost_usd_e6: opt_i64_at(batch, 11, i),
            // Unparseable payload text degrades to JSON null rather than erroring.
            payload: serde_json::from_str(s(12).value(i)).unwrap_or(serde_json::Value::Null),
            stop_reason: opt_str(batch, 13, i),
            latency_ms: opt_u32_at(batch, 14, i),
            ttft_ms: opt_u32_at(batch, 15, i),
            retry_count: opt_u16_at(batch, 16, i),
            context_used_tokens: opt_u32_at(batch, 17, i),
            context_max_tokens: opt_u32_at(batch, 18, i),
            cache_creation_tokens: opt_u32_at(batch, 19, i),
            cache_read_tokens: opt_u32_at(batch, 20, i),
            system_prompt_tokens: opt_u32_at(batch, 21, i),
        })
        .collect();
    Ok(out)
}

307pub fn partition_day(ts_ms: u64) -> Result<String> {
308    let ts = OffsetDateTime::from_unix_timestamp((ts_ms / 1000) as i64)?;
309    let text = ts.date().format(&Iso8601::DATE)?;
310    Ok(text)
311}
312
313fn vals<T: Copy>(events: &[Event], f: impl Fn(&Event) -> T) -> Vec<T> {
314    events.iter().map(f).collect()
315}
316
317fn strs(events: &[Event], f: impl Fn(&Event) -> &str) -> Vec<String> {
318    events.iter().map(|e| f(e).to_string()).collect()
319}
320
321fn opts(events: &[Event], f: impl Fn(&Event) -> Option<String>) -> Vec<Option<String>> {
322    events.iter().map(f).collect()
323}
324
325fn opt_u32(events: &[Event], f: impl Fn(&Event) -> Option<u32>) -> Vec<Option<u32>> {
326    events.iter().map(f).collect()
327}
328
329fn opt_u16(events: &[Event], f: impl Fn(&Event) -> Option<u16>) -> Vec<Option<u16>> {
330    events.iter().map(f).collect()
331}
332
333fn opt_i64(events: &[Event], f: impl Fn(&Event) -> Option<i64>) -> Vec<Option<i64>> {
334    events.iter().map(f).collect()
335}
336
337fn str_col(batch: &RecordBatch, i: usize) -> &StringArray {
338    batch.column(i).as_any().downcast_ref().unwrap()
339}
340
341fn u64_col(batch: &RecordBatch, i: usize) -> &UInt64Array {
342    batch.column(i).as_any().downcast_ref().unwrap()
343}
344
345fn bool_col(batch: &RecordBatch, i: usize) -> &BooleanArray {
346    batch.column(i).as_any().downcast_ref().unwrap()
347}
348
349fn opt_str(batch: &RecordBatch, col: usize, row: usize) -> Option<String> {
350    let a = str_col(batch, col);
351    (!a.is_null(row)).then(|| a.value(row).to_string())
352}
353
354fn opt_u32_at(batch: &RecordBatch, col: usize, row: usize) -> Option<u32> {
355    let a = batch
356        .column(col)
357        .as_any()
358        .downcast_ref::<UInt32Array>()
359        .unwrap();
360    (!a.is_null(row)).then(|| a.value(row))
361}
362
363fn opt_u16_at(batch: &RecordBatch, col: usize, row: usize) -> Option<u16> {
364    let a = batch
365        .column(col)
366        .as_any()
367        .downcast_ref::<UInt16Array>()
368        .unwrap();
369    (!a.is_null(row)).then(|| a.value(row))
370}
371
372fn opt_i64_at(batch: &RecordBatch, col: usize, row: usize) -> Option<i64> {
373    let a = batch
374        .column(col)
375        .as_any()
376        .downcast_ref::<Int64Array>()
377        .unwrap();
378    (!a.is_null(row)).then(|| a.value(row))
379}
380
/// Encode an `EventKind` as its stable on-disk string tag.
/// Must stay in sync with `kind_from`, the decoder.
fn kind(kind: &EventKind) -> &'static str {
    match kind {
        EventKind::ToolCall => "ToolCall",
        EventKind::ToolResult => "ToolResult",
        EventKind::Message => "Message",
        EventKind::Error => "Error",
        EventKind::Cost => "Cost",
        EventKind::Hook => "Hook",
        EventKind::Lifecycle => "Lifecycle",
    }
}

/// Encode an `EventSource` as its stable on-disk string tag.
/// Must stay in sync with `source_from`, the decoder.
fn source(source: &EventSource) -> &'static str {
    match source {
        EventSource::Tail => "Tail",
        EventSource::Hook => "Hook",
        EventSource::Proxy => "Proxy",
    }
}

/// Decode an on-disk kind tag back to `EventKind`.
/// Unknown or future tags (and the literal "Message" tag itself) fall back
/// to `Message` via the catch-all arm, so old readers tolerate new writers.
fn kind_from(s: &str) -> EventKind {
    match s {
        "ToolCall" => EventKind::ToolCall,
        "ToolResult" => EventKind::ToolResult,
        "Error" => EventKind::Error,
        "Cost" => EventKind::Cost,
        "Hook" => EventKind::Hook,
        "Lifecycle" => EventKind::Lifecycle,
        _ => EventKind::Message,
    }
}

/// Decode an on-disk source tag back to `EventSource`.
/// Unknown or future tags (and "Tail" itself) fall back to `Tail`.
fn source_from(s: &str) -> EventSource {
    match s {
        "Hook" => EventSource::Hook,
        "Proxy" => EventSource::Proxy,
        _ => EventSource::Tail,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Round-trip smoke test: one event written to a day partition must come
    // back from `read_events_dir` with its fields intact.
    #[test]
    fn write_and_read_daily_events() {
        let dir = tempfile::tempdir().unwrap();
        // ts_ms 1_700_000_000_000 falls on 2023-11-14 UTC, so exactly one
        // day partition file is expected.
        let event = Event {
            session_id: "s1".into(),
            seq: 0,
            ts_ms: 1_700_000_000_000,
            ts_exact: true,
            kind: EventKind::Message,
            source: EventSource::Hook,
            tool: None,
            tool_call_id: None,
            tokens_in: Some(10),
            tokens_out: None,
            reasoning_tokens: None,
            cost_usd_e6: Some(5),
            stop_reason: None,
            latency_ms: None,
            ttft_ms: None,
            retry_count: None,
            context_used_tokens: Some(12),
            context_max_tokens: None,
            cache_creation_tokens: None,
            cache_read_tokens: None,
            system_prompt_tokens: None,
            payload: json!({"type": "note"}),
        };
        let paths = write_daily_events(dir.path(), std::slice::from_ref(&event)).unwrap();
        assert_eq!(paths.len(), 1);
        let rows = read_events_dir(dir.path()).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].session_id, event.session_id);
        assert_eq!(rows[0].cost_usd_e6, Some(5));
        assert_eq!(rows[0].context_used_tokens, Some(12));
    }
}
461}