cerememory_adapter_openai/
lib.rs1pub mod provider;
11pub use provider::OpenAIProvider;
12
13use cerememory_core::{
14 estimate_tokens_from_bytes, LLMAdapter, MemoryContent, MemoryRecord, ModelInfo,
15};
16use serde::Serialize;
17
/// Adapter type for OpenAI models.
///
/// The type carries no fields, so every instance is interchangeable; it exists
/// to hang the `LLMAdapter` implementation on.
#[derive(Clone, Debug, Default)]
pub struct OpenAIAdapter;

impl OpenAIAdapter {
    /// Creates a new adapter.
    ///
    /// Equivalent to [`OpenAIAdapter::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
31
32#[derive(Serialize)]
34struct MemoryEntry {
35 id: String,
36 store: String,
37 fidelity: f64,
38 content: String,
39}
40
41#[derive(Serialize)]
43struct MemoriesPayload {
44 memories: Vec<MemoryEntry>,
45}
46
47impl LLMAdapter for OpenAIAdapter {
48 fn serialize_context(&self, memories: &[MemoryRecord], budget_tokens: usize) -> String {
49 let mut entries = Vec::new();
50 let mut accumulated_tokens = 0;
51
52 let envelope_overhead = r#"{"memories":[]}"#.len();
54 let envelope_tokens = estimate_tokens_from_bytes(envelope_overhead);
55 if envelope_tokens > budget_tokens {
56 return String::new();
57 }
58 accumulated_tokens += envelope_tokens;
59
60 for record in memories {
61 let text = match record.text_content() {
62 Some(t) => t,
63 None => continue,
64 };
65
66 let entry = MemoryEntry {
67 id: record.id.to_string(),
68 store: record.store.to_string(),
69 fidelity: record.fidelity.score,
70 content: text.to_string(),
71 };
72
73 let entry_json = serde_json::to_string(&entry).unwrap_or_default();
75 let entry_tokens = estimate_tokens_from_bytes(entry_json.len() + 1);
77
78 if accumulated_tokens + entry_tokens > budget_tokens {
79 break;
80 }
81
82 accumulated_tokens += entry_tokens;
83 entries.push(entry);
84 }
85
86 let payload = MemoriesPayload { memories: entries };
87 serde_json::to_string(&payload).unwrap_or_default()
88 }
89
90 fn estimate_tokens(&self, content: &MemoryContent) -> usize {
91 let total_bytes: usize = content.blocks.iter().map(|b| b.data.len()).sum();
92 let summary_bytes = content.summary.as_ref().map(|s| s.len()).unwrap_or(0);
93 estimate_tokens_from_bytes(total_bytes + summary_bytes)
94 }
95
96 fn model_info(&self) -> ModelInfo {
97 ModelInfo {
98 provider: "openai".to_string(),
99 model_name: "gpt-4o".to_string(),
100 max_context_tokens: 128_000,
101 }
102 }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use cerememory_core::{MemoryRecord, StoreType};

    /// Builds a text record in the semantic store.
    fn semantic_record(text: &str) -> MemoryRecord {
        MemoryRecord::new_text(StoreType::Semantic, text)
    }

    /// Parses adapter output and returns the elements of its `memories` array.
    ///
    /// Panics if the output is not valid JSON or has no `memories` array,
    /// which fails the calling test.
    fn parsed_memories(json: &str) -> Vec<serde_json::Value> {
        let root: serde_json::Value = serde_json::from_str(json).unwrap();
        root["memories"].as_array().unwrap().clone()
    }

    #[test]
    fn serialize_context_produces_non_empty_output() {
        let records = [semantic_record("Hello, world!")];
        let output = OpenAIAdapter::new().serialize_context(&records, 1000);
        assert!(!output.is_empty());

        let memories = parsed_memories(&output);
        assert_eq!(memories.len(), 1);
        assert_eq!(memories[0]["content"].as_str(), Some("Hello, world!"));
    }

    #[test]
    fn serialize_context_respects_token_budget() {
        let adapter = OpenAIAdapter::new();
        let mut records: Vec<MemoryRecord> = Vec::with_capacity(100);
        for i in 0..100 {
            let text = format!(
                "This is memory number {} with some longer content to consume tokens",
                i
            );
            records.push(semantic_record(&text));
        }

        // Same input, two budgets: the tight one must admit strictly fewer entries.
        let small_count = parsed_memories(&adapter.serialize_context(&records, 50)).len();
        let large_count = parsed_memories(&adapter.serialize_context(&records, 100_000)).len();

        assert!(
            small_count < large_count,
            "Small budget should include fewer memories ({} vs {})",
            small_count,
            large_count
        );
    }

    #[test]
    fn serialize_context_zero_budget_produces_empty() {
        let records = [semantic_record("Hello")];
        let output = OpenAIAdapter::new().serialize_context(&records, 0);
        assert!(output.is_empty());
    }

    #[test]
    fn estimate_tokens_returns_reasonable_value() {
        let record = semantic_record("Hello, world!");
        let tokens = OpenAIAdapter::new().estimate_tokens(&record.content);
        // "Hello, world!" is 13 bytes, so the estimate must land in 1..=13.
        assert!((1..=13).contains(&tokens));
    }

    #[test]
    fn model_info_returns_correct_provider() {
        let info = OpenAIAdapter::new().model_info();
        assert_eq!(info.provider, "openai");
        assert_eq!(info.model_name, "gpt-4o");
        assert_eq!(info.max_context_tokens, 128_000);
    }
}