1use crate::cell::Cell;
2use crate::error::NotebookError;
3use crate::timing::{CellTiming, TimingSummary};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fs;
7use std::path::Path;
8use std::str::FromStr;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Notebook {
13 pub cells: Vec<Cell>,
15 pub metadata: NotebookMetadata,
17 pub nbformat: u32,
19 pub nbformat_minor: u32,
21}
22
23impl Notebook {
24 pub fn from_file(path: &Path) -> Result<Self, NotebookError> {
26 let contents = fs::read_to_string(path)?;
27 contents.parse()
28 }
29
30 pub fn to_file(&self, path: &Path) -> Result<(), NotebookError> {
32 let contents = serde_json::to_string_pretty(self)?;
33 fs::write(path, contents)?;
34 Ok(())
35 }
36
37 pub fn to_string(&self) -> Result<String, NotebookError> {
39 Ok(serde_json::to_string_pretty(self)?)
40 }
41
42 pub fn get_cell(&self, index: usize) -> Result<&Cell, NotebookError> {
44 self.cells
45 .get(index)
46 .ok_or(NotebookError::CellIndexOutOfBounds(index, self.cells.len()))
47 }
48
49 pub fn get_cell_mut(&mut self, index: usize) -> Result<&mut Cell, NotebookError> {
51 let len = self.cells.len();
52 self.cells
53 .get_mut(index)
54 .ok_or(NotebookError::CellIndexOutOfBounds(index, len))
55 }
56
57 pub fn get_cells(&self, indices: &[usize]) -> Result<Vec<&Cell>, NotebookError> {
59 indices.iter().map(|&i| self.get_cell(i)).collect()
60 }
61
62 pub fn cell_count(&self) -> usize {
64 self.cells.len()
65 }
66
67 pub fn code_cell_count(&self) -> usize {
69 self.cells.iter().filter(|c| c.is_code()).count()
70 }
71
72 pub fn cell_timings(&self) -> Vec<CellTiming> {
77 self.cells
78 .iter()
79 .enumerate()
80 .filter_map(|(i, cell)| {
81 if let Cell::Code(code_cell) = cell {
82 let timing = code_cell.timing()?;
83 Some(CellTiming::new(i, code_cell.execution_count, timing))
84 } else {
85 None
86 }
87 })
88 .collect()
89 }
90
91 pub fn timing_summary(&self) -> Option<TimingSummary> {
96 let timings = self.cell_timings();
97 TimingSummary::from_cells(&timings)
98 }
99
100 pub fn cells_with_timing_count(&self) -> usize {
102 self.cells
103 .iter()
104 .filter(|cell| {
105 if let Cell::Code(code_cell) = cell {
106 code_cell.has_timing()
107 } else {
108 false
109 }
110 })
111 .count()
112 }
113}
114
115impl FromStr for Notebook {
116 type Err = NotebookError;
117
118 fn from_str(s: &str) -> Result<Self, Self::Err> {
119 let notebook: Notebook = serde_json::from_str(s)?;
120
121 if notebook.nbformat < 4 {
123 return Err(NotebookError::UnsupportedVersion(
124 notebook.nbformat,
125 notebook.nbformat_minor,
126 ));
127 }
128
129 Ok(notebook)
130 }
131}
132
133#[derive(Debug, Clone, Default, Serialize, Deserialize)]
135pub struct NotebookMetadata {
136 #[serde(skip_serializing_if = "Option::is_none")]
138 pub kernelspec: Option<KernelSpec>,
139 #[serde(skip_serializing_if = "Option::is_none")]
141 pub language_info: Option<LanguageInfo>,
142 #[serde(flatten)]
144 pub extra: HashMap<String, serde_json::Value>,
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct KernelSpec {
149 pub name: String,
150 pub display_name: String,
151 #[serde(skip_serializing_if = "Option::is_none")]
152 pub language: Option<String>,
153}
154
155#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct LanguageInfo {
157 pub name: String,
158 #[serde(skip_serializing_if = "Option::is_none")]
159 pub version: Option<String>,
160 #[serde(skip_serializing_if = "Option::is_none")]
161 pub file_extension: Option<String>,
162 #[serde(flatten)]
163 pub extra: HashMap<String, serde_json::Value>,
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty nbformat 4.5 notebook with default metadata.
    fn create_minimal_notebook() -> Notebook {
        Notebook {
            cells: vec![],
            metadata: NotebookMetadata::default(),
            nbformat: 4,
            nbformat_minor: 5,
        }
    }

    /// JSON for an nbformat 4.5 notebook holding a single code cell.
    fn create_test_notebook_json() -> &'static str {
        r#"{
            "cells": [
                {
                    "cell_type": "code",
                    "execution_count": null,
                    "id": "test-1",
                    "metadata": {},
                    "outputs": [],
                    "source": ["print('hello')"]
                }
            ],
            "metadata": {},
            "nbformat": 4,
            "nbformat_minor": 5
        }"#
    }

    /// A well-formed notebook parses and exposes its version and cells.
    #[test]
    fn test_parse_valid_notebook() -> Result<(), Box<dyn std::error::Error>> {
        let json = create_test_notebook_json();
        let notebook: Notebook = json.parse()?;

        assert_eq!(notebook.nbformat, 4);
        assert_eq!(notebook.nbformat_minor, 5);
        assert_eq!(notebook.cells.len(), 1);
        Ok(())
    }

    /// A code cell without `execution_count` and `outputs` still parses.
    #[test]
    fn test_parse_notebook_with_missing_optional_fields() {
        let json = r#"{
            "cells": [{
                "cell_type": "code",
                "id": "test",
                "metadata": {},
                "source": []
            }],
            "metadata": {},
            "nbformat": 4,
            "nbformat_minor": 5
        }"#;

        let result: Result<Notebook, _> = json.parse();
        assert!(result.is_ok());
    }

    /// Text that is not valid JSON is rejected.
    #[test]
    fn test_parse_malformed_json() {
        let json = "{ invalid json }";
        let result: Result<Notebook, _> = json.parse();
        assert!(result.is_err());
    }

    /// An nbformat 3 document is rejected with `UnsupportedVersion(3, 0)`.
    #[test]
    fn test_unsupported_version() {
        let json = r#"{
            "cells": [],
            "metadata": {},
            "nbformat": 3,
            "nbformat_minor": 0
        }"#;

        let result: Result<Notebook, _> = json.parse();
        assert!(result.is_err());

        assert!(matches!(
            result,
            Err(NotebookError::UnsupportedVersion(3, 0))
        ));
    }

    /// Serialized output is valid JSON carrying the version fields.
    #[test]
    fn test_serialize_notebook() -> Result<(), Box<dyn std::error::Error>> {
        let notebook = create_minimal_notebook();
        let json = notebook.to_string()?;

        let parsed: serde_json::Value = serde_json::from_str(&json)?;
        assert_eq!(parsed["nbformat"], 4);
        assert_eq!(parsed["nbformat_minor"], 5);
        Ok(())
    }

    /// Parse -> serialize -> parse keeps the cell count and major version.
    #[test]
    fn test_roundtrip_serialization() -> Result<(), Box<dyn std::error::Error>> {
        let json = create_test_notebook_json();
        let notebook1: Notebook = json.parse()?;
        let serialized = notebook1.to_string()?;
        let notebook2: Notebook = serialized.parse()?;

        assert_eq!(notebook1.cells.len(), notebook2.cells.len());
        assert_eq!(notebook1.nbformat, notebook2.nbformat);
        Ok(())
    }

    /// `get_cell` returns the code cell at index 0.
    #[test]
    fn test_get_cell() -> Result<(), Box<dyn std::error::Error>> {
        let json = create_test_notebook_json();
        let notebook: Notebook = json.parse()?;

        let cell = notebook.get_cell(0)?;
        assert!(matches!(cell, Cell::Code(_)));
        Ok(())
    }

    /// `get_cell` on an empty notebook reports index 0 out of 0 cells.
    #[test]
    fn test_get_cell_out_of_bounds() {
        let notebook = create_minimal_notebook();
        let result = notebook.get_cell(0);

        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(NotebookError::CellIndexOutOfBounds(0, 0))
        ));
    }

    /// `cell_count` matches the number of cells for empty and parsed notebooks.
    #[test]
    fn test_cell_count() -> Result<(), Box<dyn std::error::Error>> {
        let notebook = create_minimal_notebook();
        assert_eq!(notebook.cell_count(), 0);

        let json = create_test_notebook_json();
        let notebook: Notebook = json.parse()?;
        assert_eq!(notebook.cell_count(), 1);
        Ok(())
    }

    /// Only code cells are counted by `code_cell_count`.
    #[test]
    fn test_code_cell_count() -> Result<(), Box<dyn std::error::Error>> {
        let json = r#"{
            "cells": [
                {
                    "cell_type": "code",
                    "id": "1",
                    "metadata": {},
                    "source": []
                },
                {
                    "cell_type": "markdown",
                    "id": "2",
                    "metadata": {},
                    "source": []
                },
                {
                    "cell_type": "code",
                    "id": "3",
                    "metadata": {},
                    "source": []
                }
            ],
            "metadata": {},
            "nbformat": 4,
            "nbformat_minor": 5
        }"#;

        let notebook: Notebook = json.parse()?;
        assert_eq!(notebook.code_cell_count(), 2);
        assert_eq!(notebook.cell_count(), 3);
        Ok(())
    }

    /// Arbitrary cell metadata keys survive a parse/serialize cycle.
    #[test]
    fn test_preserve_cell_metadata() -> Result<(), Box<dyn std::error::Error>> {
        let json = r#"{
            "cells": [{
                "cell_type": "code",
                "id": "test",
                "metadata": {
                    "tags": ["important"],
                    "custom": "value"
                },
                "source": []
            }],
            "metadata": {},
            "nbformat": 4,
            "nbformat_minor": 5
        }"#;

        let notebook: Notebook = json.parse()?;
        let serialized = notebook.to_string()?;

        assert!(serialized.contains("important"));
        assert!(serialized.contains("custom"));
        Ok(())
    }

    /// An empty notebook has no cells and no code cells.
    #[test]
    fn test_empty_notebook() {
        let notebook = create_minimal_notebook();
        assert_eq!(notebook.cells.len(), 0);
        assert_eq!(notebook.code_cell_count(), 0);
    }

    /// `kernelspec` and `language_info` are parsed into their typed fields.
    #[test]
    fn test_notebook_with_kernel_metadata() -> Result<(), Box<dyn std::error::Error>> {
        let json = r#"{
            "cells": [],
            "metadata": {
                "kernelspec": {
                    "name": "python3",
                    "display_name": "Python 3",
                    "language": "python"
                },
                "language_info": {
                    "name": "python",
                    "version": "3.11.0"
                }
            },
            "nbformat": 4,
            "nbformat_minor": 5
        }"#;

        let notebook: Notebook = json.parse()?;
        assert!(notebook.metadata.kernelspec.is_some());
        assert!(notebook.metadata.language_info.is_some());

        if let Some(kernelspec) = notebook.metadata.kernelspec {
            assert_eq!(kernelspec.name, "python3");
            assert_eq!(kernelspec.display_name, "Python 3");
        }
        Ok(())
    }

    /// A negative `execution_count` is accepted and preserved as `Some(-1)`.
    #[test]
    fn test_negative_execution_count() -> Result<(), Box<dyn std::error::Error>> {
        let json = r#"{
            "cells": [
                {
                    "cell_type": "code",
                    "execution_count": -1,
                    "id": "test-negative",
                    "metadata": {},
                    "outputs": [],
                    "source": ["print('test')"]
                }
            ],
            "metadata": {},
            "nbformat": 4,
            "nbformat_minor": 5
        }"#;

        let notebook: Notebook = json.parse()?;
        assert_eq!(notebook.cells.len(), 1);

        assert!(
            matches!(
                &notebook.cells[0],
                Cell::Code(code_cell) if code_cell.execution_count == Some(-1)
            ),
            "Expected code cell with execution_count -1"
        );
        Ok(())
    }
}
436}