Skip to main content

cerememory_adapter_openai/
lib.rs

//! Cerememory LLM adapter for OpenAI.
//!
//! Implements the `LLMAdapter` trait for OpenAI models, serializing memories
//! as a JSON document intended for injection into an OpenAI system message.
//!
//! Also provides [`OpenAIProvider`], an [`LLMProvider`](cerememory_core::LLMProvider)
//! implementation that calls the OpenAI embeddings and chat completions APIs
//! with exponential-backoff retry on transient errors.
9
10pub mod provider;
11pub use provider::OpenAIProvider;
12
13use cerememory_core::{
14    estimate_tokens_from_bytes, LLMAdapter, MemoryContent, MemoryRecord, ModelInfo,
15};
16use serde::Serialize;
17
/// LLM adapter for OpenAI models.
///
/// Serializes memories as a JSON object with a `memories` array,
/// suitable for injection into OpenAI system messages.
///
/// Phase 1: only text content is serialized (records without text are
/// skipped), and token counts are estimated from byte lengths rather than
/// computed with a real tokenizer.
#[derive(Debug, Clone, Default)]
pub struct OpenAIAdapter;

impl OpenAIAdapter {
    /// Creates a new adapter.
    ///
    /// The adapter carries no state, so this is equivalent to
    /// [`OpenAIAdapter::default()`]. Because it is a `const fn`, it can also
    /// initialize `const` and `static` items.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
31
32/// Internal struct for JSON serialization of a single memory entry.
33#[derive(Serialize)]
34struct MemoryEntry {
35    id: String,
36    store: String,
37    fidelity: f64,
38    content: String,
39}
40
41/// Wrapper for the memories array.
42#[derive(Serialize)]
43struct MemoriesPayload {
44    memories: Vec<MemoryEntry>,
45}
46
47impl LLMAdapter for OpenAIAdapter {
48    fn serialize_context(&self, memories: &[MemoryRecord], budget_tokens: usize) -> String {
49        let mut entries = Vec::new();
50        let mut accumulated_tokens = 0;
51
52        // Account for the JSON envelope overhead: {"memories":[]}
53        let envelope_overhead = r#"{"memories":[]}"#.len();
54        let envelope_tokens = estimate_tokens_from_bytes(envelope_overhead);
55        if envelope_tokens > budget_tokens {
56            return String::new();
57        }
58        accumulated_tokens += envelope_tokens;
59
60        for record in memories {
61            let text = match record.text_content() {
62                Some(t) => t,
63                None => continue,
64            };
65
66            let entry = MemoryEntry {
67                id: record.id.to_string(),
68                store: record.store.to_string(),
69                fidelity: record.fidelity.score,
70                content: text.to_string(),
71            };
72
73            // Estimate the token cost of this entry (including JSON overhead like commas, braces)
74            let entry_json = serde_json::to_string(&entry).unwrap_or_default();
75            // Add 1 for the comma separator between entries
76            let entry_tokens = estimate_tokens_from_bytes(entry_json.len() + 1);
77
78            if accumulated_tokens + entry_tokens > budget_tokens {
79                break;
80            }
81
82            accumulated_tokens += entry_tokens;
83            entries.push(entry);
84        }
85
86        let payload = MemoriesPayload { memories: entries };
87        serde_json::to_string(&payload).unwrap_or_default()
88    }
89
90    fn estimate_tokens(&self, content: &MemoryContent) -> usize {
91        let total_bytes: usize = content.blocks.iter().map(|b| b.data.len()).sum();
92        let summary_bytes = content.summary.as_ref().map(|s| s.len()).unwrap_or(0);
93        estimate_tokens_from_bytes(total_bytes + summary_bytes)
94    }
95
96    fn model_info(&self) -> ModelInfo {
97        ModelInfo {
98            provider: "openai".to_string(),
99            model_name: "gpt-4o".to_string(),
100            max_context_tokens: 128_000,
101        }
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use cerememory_core::{MemoryRecord, StoreType};

    fn make_adapter() -> OpenAIAdapter {
        OpenAIAdapter::new()
    }

    fn make_record(text: &str) -> MemoryRecord {
        MemoryRecord::new_text(StoreType::Semantic, text)
    }

    /// Parses serialized output and returns the `content` of each entry, in order.
    fn contents(output: &str) -> Vec<String> {
        let parsed: serde_json::Value = serde_json::from_str(output).unwrap();
        parsed["memories"]
            .as_array()
            .unwrap()
            .iter()
            .map(|m| m["content"].as_str().unwrap().to_string())
            .collect()
    }

    #[test]
    fn serialize_context_produces_non_empty_output() {
        let adapter = make_adapter();
        let records = vec![make_record("Hello, world!")];
        let output = adapter.serialize_context(&records, 1000);
        assert!(!output.is_empty());

        // Validate it's valid JSON
        let parsed: serde_json::Value = serde_json::from_str(&output).unwrap();
        let memories = parsed["memories"].as_array().unwrap();
        assert_eq!(memories.len(), 1);
        assert_eq!(memories[0]["content"], "Hello, world!");
    }

    #[test]
    fn serialize_context_emits_record_metadata() {
        let adapter = make_adapter();
        let record = make_record("Hello, world!");
        let output = adapter.serialize_context(std::slice::from_ref(&record), 1000);

        let parsed: serde_json::Value = serde_json::from_str(&output).unwrap();
        let entry = &parsed["memories"][0];
        assert_eq!(entry["id"], record.id.to_string());
        assert_eq!(entry["store"], record.store.to_string());
        assert!(entry["fidelity"].is_number());
    }

    #[test]
    fn serialize_context_respects_token_budget() {
        let adapter = make_adapter();
        let records: Vec<MemoryRecord> = (0..100)
            .map(|i| {
                make_record(&format!(
                    "This is memory number {} with some longer content to consume tokens",
                    i
                ))
            })
            .collect();

        let small_output = adapter.serialize_context(&records, 50);
        let large_output = adapter.serialize_context(&records, 100_000);

        // Both should be valid JSON
        let small_parsed: serde_json::Value = serde_json::from_str(&small_output).unwrap();
        let large_parsed: serde_json::Value = serde_json::from_str(&large_output).unwrap();

        let small_count = small_parsed["memories"].as_array().unwrap().len();
        let large_count = large_parsed["memories"].as_array().unwrap().len();

        assert!(
            small_count < large_count,
            "Small budget should include fewer memories ({} vs {})",
            small_count,
            large_count
        );
    }

    #[test]
    fn serialize_context_keeps_input_order_and_truncates_to_prefix() {
        let adapter = make_adapter();
        let texts: Vec<String> = (0..20).map(|i| format!("memory {}", i)).collect();
        let records: Vec<MemoryRecord> = texts.iter().map(|t| make_record(t)).collect();

        let full = contents(&adapter.serialize_context(&records, 100_000));
        assert_eq!(full, texts);

        // A tight budget must keep a leading run of the input, never skip ahead.
        let truncated = contents(&adapter.serialize_context(&records, 100));
        assert!(truncated.len() < full.len());
        assert_eq!(truncated[..], full[..truncated.len()]);
    }

    #[test]
    fn serialize_context_zero_budget_produces_empty() {
        let adapter = make_adapter();
        let records = vec![make_record("Hello")];
        let output = adapter.serialize_context(&records, 0);
        assert!(output.is_empty());
    }

    #[test]
    fn serialize_context_empty_input_produces_empty_array() {
        let adapter = make_adapter();
        let output = adapter.serialize_context(&[], 1000);
        assert_eq!(output, r#"{"memories":[]}"#);
    }

    #[test]
    fn serialize_context_envelope_only_budget_produces_empty_array() {
        let adapter = make_adapter();
        let records = vec![make_record("Hello")];
        // Exactly enough budget for `{"memories":[]}` but not for any entry:
        // the result is an empty array, not an empty string.
        let budget = estimate_tokens_from_bytes(r#"{"memories":[]}"#.len());
        let output = adapter.serialize_context(&records, budget);
        assert_eq!(output, r#"{"memories":[]}"#);
    }

    #[test]
    fn estimate_tokens_returns_reasonable_value() {
        let adapter = make_adapter();
        let record = make_record("Hello, world!");
        let tokens = adapter.estimate_tokens(&record.content);
        assert!(tokens > 0);
        assert!(tokens <= 13);
    }

    #[test]
    fn model_info_returns_correct_provider() {
        let adapter = make_adapter();
        let info = adapter.model_info();
        assert_eq!(info.provider, "openai");
        assert_eq!(info.model_name, "gpt-4o");
        assert_eq!(info.max_context_tokens, 128_000);
    }
}