1use std::io::{Read, Write};
8
9use futures::TryStreamExt;
10use serde::{Deserialize, Serialize};
11
12use hirn_core::agent::AgentRecord;
13use hirn_core::episodic::EpisodicRecord;
14use hirn_core::namespace::NamespaceRecord;
15use hirn_core::procedural::ProceduralRecord;
16use hirn_core::semantic::SemanticRecord;
17use hirn_core::working::WorkingMemoryEntry;
18use hirn_core::{HirnError, HirnResult};
19
20use hirn_storage::PhysicalStore;
21use hirn_storage::datasets::{agent, episodic, graph, namespace, procedural, semantic, working};
22use hirn_storage::store::ScanOptions;
23
24use crate::graph::GraphEdge;
25
/// Serializable snapshot of every memory dataset.
///
/// [`export`] builds one of these from storage and writes it as JSON, and
/// [`import`] parses one back and appends its contents to storage. serde
/// serializes fields in declaration order, so the field order below is also
/// the key order of the emitted JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExportData {
    /// Snapshot format version. [`export`] always writes `1`.
    pub version: u32,
    /// Working-memory entries.
    pub working: Vec<WorkingMemoryEntry>,
    /// Episodic memory records.
    pub episodic: Vec<EpisodicRecord>,
    /// Semantic memory records.
    pub semantic: Vec<SemanticRecord>,
    /// Procedural memory records. This defaults to empty when the key is
    /// absent, so snapshots that lack this section still deserialize.
    #[serde(default)]
    pub procedural: Vec<ProceduralRecord>,
    /// Agent records.
    pub agents: Vec<AgentRecord>,
    /// Namespace records.
    pub namespaces: Vec<NamespaceRecord>,
    /// Graph edges. This defaults to empty when the key is absent, so
    /// snapshots that lack this section still deserialize.
    #[serde(default)]
    pub edges: Vec<GraphEdge>,
}
41
/// Summary of an [`export`] run: the number of records written for each
/// dataset, plus the size of the JSON document emitted.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExportReport {
    /// Number of working-memory entries exported.
    pub working_count: u64,
    /// Number of episodic records exported.
    pub episodic_count: u64,
    /// Number of semantic records exported.
    pub semantic_count: u64,
    /// Number of procedural records exported.
    pub procedural_count: u64,
    /// Number of agent records exported.
    pub agent_count: u64,
    /// Number of namespace records exported.
    pub namespace_count: u64,
    /// Number of graph edges exported.
    pub edge_count: u64,
    /// Size in bytes of the JSON document written to the output.
    pub bytes_written: u64,
}
54
/// Summary of an [`import`] run: the number of records read from the
/// snapshot for each dataset. Empty sections are reported as zero.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ImportReport {
    /// Number of working-memory entries imported.
    pub working_count: u64,
    /// Number of episodic records imported.
    pub episodic_count: u64,
    /// Number of semantic records imported.
    pub semantic_count: u64,
    /// Number of procedural records imported.
    pub procedural_count: u64,
    /// Number of agent records imported.
    pub agent_count: u64,
    /// Number of namespace records imported.
    pub namespace_count: u64,
    /// Number of graph edges imported.
    pub edge_count: u64,
}
66
67#[allow(clippy::future_not_send)]
69pub async fn export(
70 storage: &dyn PhysicalStore,
71 writer: &mut dyn Write,
72) -> HirnResult<ExportReport> {
73 let scan_opts = ScanOptions {
74 columns: None,
75 filter: None,
76 exact_filter: None,
77 order_by: None,
78 limit: None,
79 offset: None,
80 };
81
82 let working = scan_dataset(storage, working::DATASET_NAME, &scan_opts, |b| {
83 working::from_batch(b).map_err(|e| HirnError::storage(e))
84 })
85 .await?;
86
87 let episodic = scan_dataset(storage, episodic::DATASET_NAME, &scan_opts, |b| {
88 episodic::from_batch(b).map_err(|e| HirnError::storage(e))
89 })
90 .await?;
91
92 let semantic = scan_dataset(storage, semantic::DATASET_NAME, &scan_opts, |b| {
93 semantic::from_batch(b).map_err(|e| HirnError::storage(e))
94 })
95 .await?;
96
97 let procedural = scan_dataset(storage, procedural::DATASET_NAME, &scan_opts, |b| {
98 procedural::from_batch(b).map_err(|e| HirnError::storage(e))
99 })
100 .await?;
101
102 let agents = scan_dataset(storage, agent::DATASET_NAME, &scan_opts, |b| {
103 agent::from_batch(b).map_err(|e| HirnError::storage(e))
104 })
105 .await?;
106
107 let namespaces = scan_dataset(storage, namespace::DATASET_NAME, &scan_opts, |b| {
108 namespace::from_batch(b).map_err(|e| HirnError::storage(e))
109 })
110 .await?;
111
112 let edges = scan_dataset(storage, graph::DATASET_EDGES_NAME, &scan_opts, |b| {
113 graph::edges_from_batch(b).map_err(|e| HirnError::storage(e))
114 })
115 .await?;
116
117 let data = ExportData {
118 version: 1,
119 working,
120 episodic,
121 semantic,
122 procedural,
123 agents,
124 namespaces,
125 edges,
126 };
127
128 let json = serde_json::to_string_pretty(&data)
129 .map_err(|e| HirnError::storage(format!("json serialization: {e}")))?;
130
131 writer
132 .write_all(json.as_bytes())
133 .map_err(|e| HirnError::storage(format!("write: {e}")))?;
134
135 Ok(ExportReport {
136 working_count: data.working.len() as u64,
137 episodic_count: data.episodic.len() as u64,
138 semantic_count: data.semantic.len() as u64,
139 procedural_count: data.procedural.len() as u64,
140 agent_count: data.agents.len() as u64,
141 namespace_count: data.namespaces.len() as u64,
142 edge_count: data.edges.len() as u64,
143 bytes_written: json.len() as u64,
144 })
145}
146
147#[allow(clippy::future_not_send)]
149pub async fn import(
150 reader: &mut dyn Read,
151 storage: &dyn PhysicalStore,
152 embedding_dims: usize,
153) -> HirnResult<ImportReport> {
154 let mut json = String::new();
155 reader
156 .read_to_string(&mut json)
157 .map_err(|e| HirnError::storage(format!("read: {e}")))?;
158
159 let data: ExportData =
160 serde_json::from_str(&json).map_err(|e| HirnError::storage(format!("json parse: {e}")))?;
161
162 if !data.working.is_empty() {
163 let batch = working::to_batch(&data.working).map_err(|e| HirnError::storage(e))?;
164 storage
165 .append(working::DATASET_NAME, batch)
166 .await
167 .map_err(|e| HirnError::storage(e))?;
168 }
169
170 if !data.episodic.is_empty() {
171 let batch = episodic::to_batch(&data.episodic, embedding_dims)
172 .map_err(|e| HirnError::storage(e))?;
173 storage
174 .append(episodic::DATASET_NAME, batch)
175 .await
176 .map_err(|e| HirnError::storage(e))?;
177 }
178
179 if !data.semantic.is_empty() {
180 let batch = semantic::to_batch(&data.semantic, embedding_dims)
181 .map_err(|e| HirnError::storage(e))?;
182 storage
183 .append(semantic::DATASET_NAME, batch)
184 .await
185 .map_err(|e| HirnError::storage(e))?;
186 }
187
188 if !data.procedural.is_empty() {
189 let batch = procedural::to_batch(&data.procedural, embedding_dims)
190 .map_err(|e| HirnError::storage(e))?;
191 storage
192 .append(procedural::DATASET_NAME, batch)
193 .await
194 .map_err(|e| HirnError::storage(e))?;
195 }
196
197 if !data.agents.is_empty() {
198 let batch = agent::to_batch(&data.agents).map_err(|e| HirnError::storage(e))?;
199 storage
200 .append(agent::DATASET_NAME, batch)
201 .await
202 .map_err(|e| HirnError::storage(e))?;
203 }
204
205 if !data.namespaces.is_empty() {
206 let batch = namespace::to_batch(&data.namespaces).map_err(|e| HirnError::storage(e))?;
207 storage
208 .append(namespace::DATASET_NAME, batch)
209 .await
210 .map_err(|e| HirnError::storage(e))?;
211 }
212
213 if !data.edges.is_empty() {
214 let batch = graph::edges_to_batch(&data.edges).map_err(|e| HirnError::storage(e))?;
215 storage
216 .append(graph::DATASET_EDGES_NAME, batch)
217 .await
218 .map_err(|e| HirnError::storage(e))?;
219 }
220
221 Ok(ImportReport {
222 working_count: data.working.len() as u64,
223 episodic_count: data.episodic.len() as u64,
224 semantic_count: data.semantic.len() as u64,
225 procedural_count: data.procedural.len() as u64,
226 agent_count: data.agents.len() as u64,
227 namespace_count: data.namespaces.len() as u64,
228 edge_count: data.edges.len() as u64,
229 })
230}
231
232async fn scan_dataset<T>(
237 storage: &dyn PhysicalStore,
238 dataset: &str,
239 opts: &ScanOptions,
240 convert: impl Fn(&arrow_array::RecordBatch) -> HirnResult<Vec<T>>,
241) -> HirnResult<Vec<T>> {
242 let mut batches = match storage.scan_stream(dataset, opts.clone()).await {
243 Ok(b) => b,
244 Err(_) => return Ok(Vec::new()),
246 };
247
248 let mut out = Vec::new();
249 while let Some(batch) = batches.try_next().await? {
250 let records = convert(&batch)?;
251 out.extend(records);
252 }
253 Ok(out)
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use hirn_storage::memory_store::MemoryStore;

    /// A version-1 snapshot with every section empty.
    fn empty_export_data() -> ExportData {
        ExportData {
            version: 1,
            working: vec![],
            episodic: vec![],
            semantic: vec![],
            procedural: vec![],
            agents: vec![],
            namespaces: vec![],
            edges: vec![],
        }
    }

    #[tokio::test]
    async fn export_empty_storage_produces_valid_json() {
        let storage = MemoryStore::new();
        let mut buf = Vec::new();
        let report = export(&storage, &mut buf).await.unwrap();

        assert_eq!(report.working_count, 0);
        assert_eq!(report.episodic_count, 0);
        assert_eq!(report.semantic_count, 0);
        assert_eq!(report.procedural_count, 0);
        assert_eq!(report.agent_count, 0);
        assert_eq!(report.namespace_count, 0);
        assert_eq!(report.edge_count, 0);
        assert_eq!(report.bytes_written as usize, buf.len());

        let data: ExportData = serde_json::from_slice(&buf).unwrap();
        assert_eq!(data.version, 1);
        assert!(data.working.is_empty());
        assert!(data.episodic.is_empty());
        assert!(data.semantic.is_empty());
        assert!(data.procedural.is_empty());
        assert!(data.agents.is_empty());
        assert!(data.namespaces.is_empty());
        assert!(data.edges.is_empty());
    }

    #[tokio::test]
    async fn import_empty_json() {
        let storage = MemoryStore::new();
        let json = serde_json::to_string(&empty_export_data()).unwrap();
        let report = import(&mut json.as_bytes(), &storage, 768).await.unwrap();

        assert_eq!(report.working_count, 0);
        assert_eq!(report.episodic_count, 0);
        assert_eq!(report.semantic_count, 0);
        assert_eq!(report.procedural_count, 0);
        assert_eq!(report.agent_count, 0);
        assert_eq!(report.namespace_count, 0);
        assert_eq!(report.edge_count, 0);
    }

    /// Snapshots without the `#[serde(default)]` sections must still import.
    #[tokio::test]
    async fn import_accepts_snapshot_without_optional_sections() {
        let storage = MemoryStore::new();
        let json = r#"{"version":1,"working":[],"episodic":[],"semantic":[],"agents":[],"namespaces":[]}"#;
        let report = import(&mut json.as_bytes(), &storage, 768).await.unwrap();

        assert_eq!(report.procedural_count, 0);
        assert_eq!(report.edge_count, 0);
    }

    #[tokio::test]
    async fn export_then_import_round_trips_empty_storage() {
        let source = MemoryStore::new();
        let mut buf = Vec::new();
        export(&source, &mut buf).await.unwrap();

        let target = MemoryStore::new();
        let report = import(&mut buf.as_slice(), &target, 768).await.unwrap();
        assert_eq!(report.working_count, 0);
        assert_eq!(report.episodic_count, 0);
        assert_eq!(report.semantic_count, 0);
        assert_eq!(report.procedural_count, 0);
        assert_eq!(report.agent_count, 0);
        assert_eq!(report.namespace_count, 0);
        assert_eq!(report.edge_count, 0);
    }

    #[tokio::test]
    async fn import_invalid_json_returns_error() {
        let storage = MemoryStore::new();
        let bad_json = b"{ not valid json";
        let result = import(&mut bad_json.as_slice(), &storage, 768).await;
        assert!(result.is_err());
    }
}