kernel_sidecar/
notebook.rs

/*
Models a Jupyter Notebook document. Format reference: https://ipython.org/ipython-doc/3/notebook/nbformat.html
*/
4
5use crate::jupyter::iopub_content::display_data::DisplayData;
6use crate::jupyter::iopub_content::errors::Error;
7use crate::jupyter::iopub_content::execute_result::ExecuteResult;
8use crate::jupyter::iopub_content::stream::Stream;
9use enum_as_inner::EnumAsInner;
10use serde::ser::SerializeMap;
11use serde::{Deserialize, Deserializer, Serialize, Serializer};
12
13#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
14pub struct Notebook {
15    pub cells: Vec<Cell>,
16    #[serde(skip_serializing_if = "Option::is_none")]
17    pub signature: Option<String>,
18    #[serde(
19        serialize_with = "serialize_json_value_as_empty_object",
20        deserialize_with = "serde_json::value::Value::deserialize"
21    )]
22    pub metadata: serde_json::Value,
23    pub nbformat: u32,
24    pub nbformat_minor: u32,
25}
26
27impl Default for Notebook {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33impl Notebook {
34    pub fn new() -> Self {
35        Self {
36            cells: vec![],
37            signature: None,
38            metadata: serde_json::Value::Null,
39            nbformat: 4,
40            nbformat_minor: 5,
41        }
42    }
43
44    pub fn from_file(filename: &str) -> Self {
45        let content = std::fs::read_to_string(filename).unwrap();
46        serde_json::from_str(&content).unwrap()
47    }
48
49    pub fn save(&self, filename: &str) {
50        let js = serde_json::to_string_pretty(&self).expect("Failed to serialize notebook on save");
51        std::fs::write(filename, js).unwrap();
52    }
53
54    pub fn dumps(&self) -> String {
55        serde_json::to_string_pretty(&self).expect("Failed to serialize notebook on save")
56    }
57
58    pub fn get_cell(&self, id: &str) -> Option<&Cell> {
59        self.cells.iter().find(|&cell| cell.id() == id)
60    }
61
62    pub fn get_mut_cell(&mut self, id: &str) -> Option<&mut Cell> {
63        self.cells.iter_mut().find(|cell| cell.id() == id)
64    }
65
66    pub fn add_cell(&mut self, cell: Cell) {
67        self.cells.push(cell);
68    }
69
70    pub fn add_code_cell(&mut self, source: &str) -> Cell {
71        let cell = Cell::Code(CodeCell {
72            id: uuid::Uuid::new_v4().to_string(),
73            source: source.to_owned(),
74            metadata: serde_json::Value::Null,
75            execution_count: None,
76            outputs: vec![],
77        });
78        self.cells.push(cell.clone());
79        cell
80    }
81
82    pub fn add_markdown_cell(&mut self, source: &str) -> Cell {
83        let cell = Cell::Markdown(MarkdownCell {
84            id: uuid::Uuid::new_v4().to_string(),
85            source: source.to_owned(),
86            metadata: serde_json::Value::Null,
87        });
88        self.cells.push(cell.clone());
89        cell
90    }
91}
92
93#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, EnumAsInner)]
94#[serde(tag = "output_type", rename_all = "snake_case")]
95pub enum Output {
96    DisplayData(DisplayData),
97    Stream(Stream),
98    ExecuteResult(ExecuteResult),
99    Error(Error),
100}
101
102#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
103#[serde(tag = "cell_type", rename_all = "lowercase")]
104pub enum Cell {
105    Code(CodeCell),
106    Markdown(MarkdownCell),
107    Raw(RawCell),
108}
109
110impl Cell {
111    pub fn id(&self) -> &str {
112        match self {
113            Cell::Code(cell) => &cell.id,
114            Cell::Markdown(cell) => &cell.id,
115            Cell::Raw(cell) => &cell.id,
116        }
117    }
118
119    pub fn get_source(&self) -> String {
120        match self {
121            Cell::Code(cell) => cell.source.to_string(),
122            Cell::Markdown(cell) => cell.source.to_string(),
123            Cell::Raw(cell) => cell.source.to_string(),
124        }
125    }
126
127    pub fn set_source(&mut self, source: &str) {
128        match self {
129            Cell::Code(cell) => cell.source = source.to_string(),
130            Cell::Markdown(cell) => cell.source = source.to_string(),
131            Cell::Raw(cell) => cell.source = source.to_string(),
132        }
133    }
134
135    pub fn metadata(&self) -> &serde_json::Value {
136        match self {
137            Cell::Code(cell) => &cell.metadata,
138            Cell::Markdown(cell) => &cell.metadata,
139            Cell::Raw(cell) => &cell.metadata,
140        }
141    }
142
143    pub fn add_output(&mut self, output: Output) {
144        if let Cell::Code(cell) = self {
145            cell.add_output(output);
146        }
147    }
148
149    pub fn clear_output(&mut self) {
150        if let Cell::Code(cell) = self {
151            cell.clear_output();
152        }
153    }
154}
155
156#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
157pub struct CodeCell {
158    pub id: String,
159    #[serde(deserialize_with = "list_or_string_to_string")]
160    pub source: String,
161    #[serde(
162        serialize_with = "serialize_json_value_as_empty_object",
163        deserialize_with = "serde_json::value::Value::deserialize"
164    )]
165    pub metadata: serde_json::Value,
166    #[serde(skip_serializing_if = "Option::is_none")]
167    pub execution_count: Option<u32>,
168    pub outputs: Vec<Output>,
169}
170
171impl CodeCell {
172    pub fn add_output(&mut self, output: Output) {
173        self.outputs.push(output);
174    }
175
176    pub fn clear_output(&mut self) {
177        self.outputs = vec![];
178    }
179}
180
181#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
182pub struct MarkdownCell {
183    pub id: String,
184    #[serde(deserialize_with = "list_or_string_to_string")]
185    pub source: String,
186    #[serde(
187        serialize_with = "serialize_json_value_as_empty_object",
188        deserialize_with = "serde_json::value::Value::deserialize"
189    )]
190    pub metadata: serde_json::Value,
191}
192
193#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
194pub struct RawCell {
195    pub id: String,
196    #[serde(deserialize_with = "list_or_string_to_string")]
197    pub source: String,
198    #[serde(
199        serialize_with = "serialize_json_value_as_empty_object",
200        deserialize_with = "serde_json::value::Value::deserialize"
201    )]
202    pub metadata: serde_json::Value,
203}
204
205// Custom deserialization for source field since it may be a Vec<String> or String
206pub fn list_or_string_to_string<'de, D>(deserializer: D) -> Result<String, D::Error>
207where
208    D: Deserializer<'de>,
209{
210    // Deserialize the source field as a serde_json::Value
211    let source_value: serde_json::Value = Deserialize::deserialize(deserializer)?;
212
213    // Check if the source is an array of strings
214    if let Some(source_array) = source_value.as_array() {
215        // Join the array of strings into a single string
216        let source_string = source_array
217            .iter()
218            .map(|s| s.as_str().unwrap_or_default())
219            .collect::<Vec<_>>()
220            .join("\n");
221
222        Ok(source_string)
223    } else if let Some(source_str) = source_value.as_str() {
224        // If source is already a string, return it
225        Ok(source_str.to_string())
226    } else {
227        Err(serde::de::Error::custom("Invalid source format"))
228    }
229}
230
231// Custom serialization for when metadata fields are null, make them empty objects instead
232fn serialize_json_value_as_empty_object<S>(
233    value: &serde_json::Value,
234    serializer: S,
235) -> Result<S::Ok, S::Error>
236where
237    S: Serializer,
238{
239    match value {
240        serde_json::Value::Null => serializer.serialize_map(Some(0))?.end(),
241        _ => value.serialize(serializer),
242    }
243}