jtool_notebook/
notebook.rs

1use crate::cell::Cell;
2use crate::error::NotebookError;
3use crate::timing::{CellTiming, TimingSummary};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::fs;
7use std::path::Path;
8use std::str::FromStr;
9
10/// A Jupyter notebook
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Notebook {
13    /// Notebook cells
14    pub cells: Vec<Cell>,
15    /// Notebook-level metadata
16    pub metadata: NotebookMetadata,
17    /// nbformat version (major)
18    pub nbformat: u32,
19    /// nbformat version (minor)
20    pub nbformat_minor: u32,
21}
22
23impl Notebook {
24    /// Read a notebook from a file
25    pub fn from_file(path: &Path) -> Result<Self, NotebookError> {
26        let contents = fs::read_to_string(path)?;
27        contents.parse()
28    }
29
30    /// Write a notebook to a file
31    pub fn to_file(&self, path: &Path) -> Result<(), NotebookError> {
32        let contents = serde_json::to_string_pretty(self)?;
33        fs::write(path, contents)?;
34        Ok(())
35    }
36
37    /// Convert notebook to JSON string
38    pub fn to_string(&self) -> Result<String, NotebookError> {
39        Ok(serde_json::to_string_pretty(self)?)
40    }
41
42    /// Get a cell by index
43    pub fn get_cell(&self, index: usize) -> Result<&Cell, NotebookError> {
44        self.cells
45            .get(index)
46            .ok_or(NotebookError::CellIndexOutOfBounds(index, self.cells.len()))
47    }
48
49    /// Get a mutable cell by index
50    pub fn get_cell_mut(&mut self, index: usize) -> Result<&mut Cell, NotebookError> {
51        let len = self.cells.len();
52        self.cells
53            .get_mut(index)
54            .ok_or(NotebookError::CellIndexOutOfBounds(index, len))
55    }
56
57    /// Get multiple cells by indices
58    pub fn get_cells(&self, indices: &[usize]) -> Result<Vec<&Cell>, NotebookError> {
59        indices.iter().map(|&i| self.get_cell(i)).collect()
60    }
61
62    /// Count cells
63    pub fn cell_count(&self) -> usize {
64        self.cells.len()
65    }
66
67    /// Count code cells
68    pub fn code_cell_count(&self) -> usize {
69        self.cells.iter().filter(|c| c.is_code()).count()
70    }
71
72    /// Get timing information for all cells
73    ///
74    /// Returns a vector of CellTiming for all code cells that have timing metadata.
75    /// Cells without timing metadata are skipped.
76    pub fn cell_timings(&self) -> Vec<CellTiming> {
77        self.cells
78            .iter()
79            .enumerate()
80            .filter_map(|(i, cell)| {
81                if let Cell::Code(code_cell) = cell {
82                    let timing = code_cell.timing()?;
83                    Some(CellTiming::new(i, code_cell.execution_count, timing))
84                } else {
85                    None
86                }
87            })
88            .collect()
89    }
90
91    /// Get timing summary for the notebook
92    ///
93    /// Returns timing statistics (total, average, min, max) for all cells with timing data.
94    /// Returns None if no cells have timing metadata.
95    pub fn timing_summary(&self) -> Option<TimingSummary> {
96        let timings = self.cell_timings();
97        TimingSummary::from_cells(&timings)
98    }
99
100    /// Count cells with timing metadata
101    pub fn cells_with_timing_count(&self) -> usize {
102        self.cells
103            .iter()
104            .filter(|cell| {
105                if let Cell::Code(code_cell) = cell {
106                    code_cell.has_timing()
107                } else {
108                    false
109                }
110            })
111            .count()
112    }
113}
114
115impl FromStr for Notebook {
116    type Err = NotebookError;
117
118    fn from_str(s: &str) -> Result<Self, Self::Err> {
119        let notebook: Notebook = serde_json::from_str(s)?;
120
121        // Validate nbformat version
122        if notebook.nbformat < 4 {
123            return Err(NotebookError::UnsupportedVersion(
124                notebook.nbformat,
125                notebook.nbformat_minor,
126            ));
127        }
128
129        Ok(notebook)
130    }
131}
132
133/// Notebook metadata
134#[derive(Debug, Clone, Default, Serialize, Deserialize)]
135pub struct NotebookMetadata {
136    /// Kernel specification
137    #[serde(skip_serializing_if = "Option::is_none")]
138    pub kernelspec: Option<KernelSpec>,
139    /// Language info
140    #[serde(skip_serializing_if = "Option::is_none")]
141    pub language_info: Option<LanguageInfo>,
142    /// Additional metadata (extensible)
143    #[serde(flatten)]
144    pub extra: HashMap<String, serde_json::Value>,
145}
146
147#[derive(Debug, Clone, Serialize, Deserialize)]
148pub struct KernelSpec {
149    pub name: String,
150    pub display_name: String,
151    #[serde(skip_serializing_if = "Option::is_none")]
152    pub language: Option<String>,
153}
154
155#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct LanguageInfo {
157    pub name: String,
158    #[serde(skip_serializing_if = "Option::is_none")]
159    pub version: Option<String>,
160    #[serde(skip_serializing_if = "Option::is_none")]
161    pub file_extension: Option<String>,
162    #[serde(flatten)]
163    pub extra: HashMap<String, serde_json::Value>,
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Return type shared by the tests that propagate errors with `?`.
    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Builds an nbformat 4.5 notebook with no cells and default metadata.
    fn create_minimal_notebook() -> Notebook {
        Notebook {
            nbformat: 4,
            nbformat_minor: 5,
            metadata: NotebookMetadata::default(),
            cells: Vec::new(),
        }
    }

    /// JSON text of an nbformat 4.5 notebook holding a single code cell.
    fn create_test_notebook_json() -> &'static str {
        r#"{
            "cells": [
                {
                    "cell_type": "code",
                    "execution_count": null,
                    "id": "test-1",
                    "metadata": {},
                    "outputs": [],
                    "source": ["print('hello')"]
                }
            ],
            "metadata": {},
            "nbformat": 4,
            "nbformat_minor": 5
        }"#
    }

    /// A well-formed notebook parses and exposes its version and cells.
    #[test]
    fn test_parse_valid_notebook() -> TestResult {
        let parsed = create_test_notebook_json().parse::<Notebook>()?;

        assert_eq!(parsed.nbformat, 4);
        assert_eq!(parsed.nbformat_minor, 5);
        assert_eq!(parsed.cells.len(), 1);
        Ok(())
    }

    /// A code cell lacking `execution_count` and `outputs` is still accepted.
    #[test]
    fn test_parse_notebook_with_missing_optional_fields() {
        let json = r#"{
            "cells": [{
                "cell_type": "code",
                "id": "test",
                "metadata": {},
                "source": []
            }],
            "metadata": {},
            "nbformat": 4,
            "nbformat_minor": 5
        }"#;

        // Optional fields such as execution_count and outputs may be absent.
        let outcome = json.parse::<Notebook>();
        assert!(outcome.is_ok());
    }

    /// Text that is not valid JSON is rejected.
    #[test]
    fn test_parse_malformed_json() {
        let outcome = "{ invalid json }".parse::<Notebook>();
        assert!(outcome.is_err());
    }

    /// A notebook with major version 3 yields `UnsupportedVersion(3, 0)`.
    #[test]
    fn test_unsupported_version() {
        let json = r#"{
            "cells": [],
            "metadata": {},
            "nbformat": 3,
            "nbformat_minor": 0
        }"#;

        let outcome = json.parse::<Notebook>();
        assert!(outcome.is_err());

        assert!(matches!(
            outcome,
            Err(NotebookError::UnsupportedVersion(3, 0))
        ));
    }

    /// Serializing produces valid JSON carrying the version fields.
    #[test]
    fn test_serialize_notebook() -> TestResult {
        let rendered = create_minimal_notebook().to_string()?;

        // The output must itself parse as JSON.
        let value: serde_json::Value = serde_json::from_str(&rendered)?;
        assert_eq!(value["nbformat"], 4);
        assert_eq!(value["nbformat_minor"], 5);
        Ok(())
    }

    /// Parse, serialize, parse again: cell count and version survive.
    #[test]
    fn test_roundtrip_serialization() -> TestResult {
        let first = create_test_notebook_json().parse::<Notebook>()?;
        let second = first.to_string()?.parse::<Notebook>()?;

        assert_eq!(first.cells.len(), second.cells.len());
        assert_eq!(first.nbformat, second.nbformat);
        Ok(())
    }

    /// The first cell of the sample notebook is a code cell.
    #[test]
    fn test_get_cell() -> TestResult {
        let notebook = create_test_notebook_json().parse::<Notebook>()?;

        let first_cell = notebook.get_cell(0)?;
        assert!(matches!(first_cell, Cell::Code(_)));
        Ok(())
    }

    /// Looking up index 0 in an empty notebook reports index 0 of length 0.
    #[test]
    fn test_get_cell_out_of_bounds() {
        let empty = create_minimal_notebook();
        let lookup = empty.get_cell(0);

        assert!(lookup.is_err());
        assert!(matches!(
            lookup,
            Err(NotebookError::CellIndexOutOfBounds(0, 0))
        ));
    }

    /// `cell_count` is 0 for the empty notebook and 1 for the sample.
    #[test]
    fn test_cell_count() -> TestResult {
        assert_eq!(create_minimal_notebook().cell_count(), 0);

        let populated = create_test_notebook_json().parse::<Notebook>()?;
        assert_eq!(populated.cell_count(), 1);
        Ok(())
    }

    /// Only the two code cells out of three cells are counted as code.
    #[test]
    fn test_code_cell_count() -> TestResult {
        let json = r#"{
            "cells": [
                {
                    "cell_type": "code",
                    "id": "1",
                    "metadata": {},
                    "source": []
                },
                {
                    "cell_type": "markdown",
                    "id": "2",
                    "metadata": {},
                    "source": []
                },
                {
                    "cell_type": "code",
                    "id": "3",
                    "metadata": {},
                    "source": []
                }
            ],
            "metadata": {},
            "nbformat": 4,
            "nbformat_minor": 5
        }"#;

        let mixed = json.parse::<Notebook>()?;
        assert_eq!(mixed.code_cell_count(), 2);
        assert_eq!(mixed.cell_count(), 3);
        Ok(())
    }

    /// Cell metadata entries show up again in the serialized output.
    #[test]
    fn test_preserve_cell_metadata() -> TestResult {
        let json = r#"{
            "cells": [{
                "cell_type": "code",
                "id": "test",
                "metadata": {
                    "tags": ["important"],
                    "custom": "value"
                },
                "source": []
            }],
            "metadata": {},
            "nbformat": 4,
            "nbformat_minor": 5
        }"#;

        let rendered = json.parse::<Notebook>()?.to_string()?;

        // Both the tag value and the custom key must survive the round trip.
        assert!(rendered.contains("important"));
        assert!(rendered.contains("custom"));
        Ok(())
    }

    /// The minimal notebook has no cells and therefore no code cells.
    #[test]
    fn test_empty_notebook() {
        let empty = create_minimal_notebook();
        assert_eq!(empty.cells.len(), 0);
        assert_eq!(empty.code_cell_count(), 0);
    }

    /// Kernel and language metadata are parsed into their typed fields.
    #[test]
    fn test_notebook_with_kernel_metadata() -> TestResult {
        let json = r#"{
            "cells": [],
            "metadata": {
                "kernelspec": {
                    "name": "python3",
                    "display_name": "Python 3",
                    "language": "python"
                },
                "language_info": {
                    "name": "python",
                    "version": "3.11.0"
                }
            },
            "nbformat": 4,
            "nbformat_minor": 5
        }"#;

        let metadata = json.parse::<Notebook>()?.metadata;
        assert!(metadata.kernelspec.is_some());
        assert!(metadata.language_info.is_some());

        if let Some(spec) = &metadata.kernelspec {
            assert_eq!(spec.name, "python3");
            assert_eq!(spec.display_name, "Python 3");
        }
        Ok(())
    }

    /// A negative `execution_count` is accepted and kept as-is.
    #[test]
    fn test_negative_execution_count() -> TestResult {
        let json = r#"{
            "cells": [
                {
                    "cell_type": "code",
                    "execution_count": -1,
                    "id": "test-negative",
                    "metadata": {},
                    "outputs": [],
                    "source": ["print('test')"]
                }
            ],
            "metadata": {},
            "nbformat": 4,
            "nbformat_minor": 5
        }"#;

        let notebook = json.parse::<Notebook>()?;
        assert_eq!(notebook.cells.len(), 1);

        let only_cell = notebook.cells.first();
        assert!(
            matches!(
                only_cell,
                Some(Cell::Code(code)) if code.execution_count == Some(-1)
            ),
            "Expected code cell with execution_count -1"
        );
        Ok(())
    }
}