1use base64;
2use base64::Engine;
3use std::collections::hash_map::DefaultHasher;
4
5use anyhow::Result;
6
7use crate::ast::Ast;
8use crate::document::{CodeOutput, Document, Image, Metadata};
9
10use crate::document;
11use linked_hash_map::LinkedHashMap;
12use nanoid::nanoid;
13use serde::de::Error;
14use serde::{Deserialize, Deserializer, Serialize, Serializer};
15use serde_json::Value;
16use serde_with::{formats::PreferOne, serde_as, EnumMap, OneOrMany};
17use std::collections::HashMap;
18use std::default::Default;
19use std::hash::{Hash, Hasher};
20use std::io::{BufWriter, Write};
21
/// Top-level Jupyter notebook, as (de)serialized from `.ipynb` JSON.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct Notebook {
    /// Notebook-level `metadata` object.
    pub metadata: NotebookMeta,
    /// Major format version; filled from `nbformat()` (4) when the JSON omits it.
    /// Note: `Notebook::default()` (derived) sets this to 0, not 4 — the serde default
    /// only applies during deserialization.
    #[serde(default = "nbformat")]
    pub nbformat: i64,
    /// Minor format version; filled from `nbformat_minor()` (5) when the JSON omits it.
    /// As above, the derived `Default` sets this to 0.
    #[serde(default = "nbformat_minor")]
    pub nbformat_minor: i64,
    /// The notebook's cells, in document order.
    pub cells: Vec<Cell>,
}
37
/// Serde default for `Notebook::nbformat` when a notebook omits the field.
const fn nbformat() -> i64 {
    const MAJOR: i64 = 4;
    MAJOR
}
41
/// Serde default for `Notebook::nbformat_minor` when a notebook omits the field.
const fn nbformat_minor() -> i64 {
    const MINOR: i64 = 5;
    MINOR
}
45
/// A notebook cell, discriminated by the JSON `cell_type` key.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(tag = "cell_type")]
pub enum Cell {
    /// `"cell_type": "markdown"`; `notebook_to_doc` copies its source verbatim.
    #[serde(rename = "markdown")]
    Markdown {
        #[serde(flatten)]
        common: CellCommon,
    },
    /// `"cell_type": "code"`, together with its recorded outputs.
    #[serde(rename = "code")]
    Code {
        #[serde(flatten)]
        common: CellCommon,

        /// Execution counter; JSON `null` maps to `None` (presumably a never-run cell).
        execution_count: Option<i64>,
        /// Outputs recorded for this cell.
        outputs: Vec<CellOutput>,
    },
    /// `"cell_type": "raw"`; `notebook_to_doc` tries to read YAML front matter from it.
    #[serde(rename = "raw")]
    Raw {
        #[serde(flatten)]
        common: CellCommon,
    },
}
75
/// Fields shared by every cell kind; flattened into each `Cell` variant.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CellCommon {
    /// Cell id; a fresh id from `get_id()` is used when the JSON omits it.
    #[serde(default = "get_id")]
    pub id: String,
    /// Per-cell `metadata` object.
    pub metadata: CellMeta,
    /// Cell source held as a single string. It is read through `concatenate_deserialize`
    /// and written back as a sequence of line strings through `concatenate_serialize`.
    #[serde(
        deserialize_with = "concatenate_deserialize",
        serialize_with = "concatenate_serialize"
    )]
    pub source: String,
}
90
/// Generates a random id (via `nanoid!()`) for a cell whose JSON has no `id` field.
fn get_id() -> String {
    nanoid!()
}
94
/// Target stream of a `stream` output. `rename_all = "lowercase"` makes the JSON
/// values `"stdout"` and `"stderr"`.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "lowercase")]
pub enum StreamType {
    StdOut,
    StdErr,
}
102
103impl ToString for StreamType {
104 fn to_string(&self) -> String {
105 match self {
106 StreamType::StdOut => "stdout".to_string(),
107 StreamType::StdErr => "stderr".to_string(),
108 }
109 }
110}
111
112#[serde_as]
115#[derive(Serialize, Deserialize, Debug, Clone)]
116#[serde(tag = "output_type")]
117pub enum CellOutput {
118 #[serde(rename = "stream")]
120 Stream {
121 name: StreamType,
122 #[serde(deserialize_with = "concatenate_deserialize")]
123 text: String,
124 },
125 #[serde(rename = "display_data", alias = "execute_result")]
127 Data {
128 #[serde(default)]
129 execution_count: i64,
130 #[serde_as(as = "EnumMap")]
132 data: Vec<OutputValue>,
133 metadata: LinkedHashMap<String, Value>,
134 },
135 #[serde(rename = "error")]
136 Error {
137 ename: String,
138 evalue: String,
139 traceback: Vec<String>,
140 },
141}
142
/// One MIME-typed entry of an output's `data` bundle. It is used through `EnumMap` in
/// `CellOutput::Data`, so each variant's renamed name is a key of the JSON `data` map.
///
/// Multi-part values use `OneOrMany<_, PreferOne>`: they are read from either a single
/// string or a list, and written as a single string when there is exactly one element.
///
/// NOTE(review): a MIME key not listed here (e.g. `text/latex`, `text/markdown`) likely
/// makes deserialization of the whole output fail — confirm and consider a fallback.
/// NOTE(review): `Javascript` is a plain `String`, so a list-form value would not
/// deserialize — confirm against real notebooks.
#[serde_as]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum OutputValue {
    /// `text/plain`
    #[serde(rename = "text/plain")]
    Plain(
        #[serde_as(
            deserialize_as = "OneOrMany<_, PreferOne>",
            serialize_as = "OneOrMany<_, PreferOne>"
        )]
        Vec<String>,
    ),
    /// `image/png`; kept as the encoded text from the notebook, not decoded bytes.
    #[serde(rename = "image/png")]
    Image(
        #[serde_as(
            deserialize_as = "OneOrMany<_, PreferOne>",
            serialize_as = "OneOrMany<_, PreferOne>"
        )]
        Vec<String>,
    ),
    /// `image/svg+xml`
    #[serde(rename = "image/svg+xml")]
    Svg(
        #[serde_as(
            deserialize_as = "OneOrMany<_, PreferOne>",
            serialize_as = "OneOrMany<_, PreferOne>"
        )]
        Vec<String>,
    ),
    /// `application/json`, kept as an arbitrary JSON value.
    #[serde(rename = "application/json")]
    Json(Value),
    /// `text/html`
    #[serde(rename = "text/html")]
    Html(
        #[serde_as(
            deserialize_as = "OneOrMany<_, PreferOne>",
            serialize_as = "OneOrMany<_, PreferOne>"
        )]
        Vec<String>,
    ),
    /// `application/javascript`
    #[serde(rename = "application/javascript")]
    Javascript(String),
}
191
/// Catch-all map for JSON keys that have no dedicated field (used with `#[serde(flatten)]`).
///
/// NOTE(review): `HashMap` iteration order is unspecified, so these keys may be written back
/// in a different order than they were read; `LinkedHashMap` (already used elsewhere in
/// this module) would preserve it.
type Dict = HashMap<String, Value>;
193
/// Notebook-level `metadata` object.
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct NotebookMeta {
    /// Kernel description as an insertion-ordered map of raw JSON values. It has no
    /// `skip_serializing_none`, so `None` is written out as `null`.
    pub kernelspec: Option<LinkedHashMap<String, Value>>,
    /// Every other metadata key, preserved as raw JSON.
    #[serde(flatten)]
    pub optional: Dict,
}
202
/// Per-cell `metadata` object. `skip_serializing_none` omits every `None` field on output.
#[serde_with::skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct CellMeta {
    pub collapsed: Option<bool>,
    /// Kept as raw JSON; not typed further here.
    pub autoscroll: Option<Value>,
    pub deletable: Option<bool>,
    /// Contents of the `jupyter` key.
    pub jupyter: Option<JupyterLabMeta>,
    pub format: Option<String>,
    pub name: Option<String>,
    /// Cell tags; `notebook_to_doc` joins them into the `#| tags:` header of code blocks.
    pub tags: Option<Vec<String>>,
    /// Any other metadata keys, preserved as raw JSON.
    #[serde(flatten)]
    pub additional: Dict,
}
220
/// Contents of a cell's `metadata.jupyter` object; `None` fields are omitted on output.
#[serde_with::skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone, Default)]
pub struct JupyterLabMeta {
    pub outputs_hidden: Option<bool>,
    pub source_hidden: Option<bool>,
}
230
231impl Notebook {
232 pub fn get_front_matter(&self) -> Result<Metadata, serde_yaml::Error> {
234 match &self.cells[0] {
235 Cell::Raw { common } => Ok(serde_yaml::from_str(&common.source)?),
236 _ => Ok(Metadata::default()),
237 }
238 }
239}
240
241fn concatenate_deserialize<'de, D>(input: D) -> Result<String, D::Error>
242where
243 D: Deserializer<'de>,
244{
245 let base: Vec<String> = Deserialize::deserialize(input)?;
246 let source: String = base.into_iter().collect();
247 Ok(source)
249}
250
251fn concatenate_serialize<S>(value: &str, serializer: S) -> Result<S::Ok, S::Error>
252where
253 S: Serializer,
254{
255 let lines: Vec<&str> = value.split('\n').collect();
256 let last = lines[lines.len() - 1];
257 let mut new_lines: Vec<String> = lines[..lines.len() - 1]
258 .iter()
259 .map(|s| format!("{}\n", s))
260 .collect();
261 new_lines.push(last.to_string());
262 serializer.collect_seq(new_lines)
263}
264
265#[allow(unused)]
266fn deserialize_png<'de, D>(input: D) -> Result<Vec<u8>, D::Error>
267where
268 D: Deserializer<'de>,
269{
270 let base: String = Deserialize::deserialize(input)?;
271 let engine = base64::engine::general_purpose::STANDARD;
272 let bytes = engine
273 .decode(base)
274 .map_err(|e| D::Error::custom(e.to_string()))?;
275 Ok(bytes)
277}
278
279#[allow(unused)]
280fn serialize_png<S>(value: &Vec<u8>, serializer: S) -> Result<S::Ok, S::Error>
281where
282 S: Serializer,
283{
284 let engine = base64::engine::general_purpose::STANDARD;
285 serializer.collect_str(&engine.encode(value))
286}
287
288impl From<Vec<CellOutput>> for CodeOutput {
289 fn from(value: Vec<CellOutput>) -> Self {
290 let mut outputs = Vec::new();
291 for output in value {
292 match output {
293 CellOutput::Stream { text, .. } => {
294 outputs.push(document::OutputValue::Text(text));
295 }
296 CellOutput::Data { data, .. } => {
297 for v in data {
298 match v {
299 OutputValue::Plain(s) => {
300 outputs.push(document::OutputValue::Plain(s.join("")));
301 }
302 OutputValue::Image(i) => {
303 outputs.push(document::OutputValue::Image(Image::Png(i.join(""))));
304 }
305 OutputValue::Svg(i) => {
306 outputs.push(document::OutputValue::Image(Image::Svg(i.join(""))));
307 }
308 OutputValue::Json(s) => {
309 outputs.push(document::OutputValue::Json(s));
310 }
311 OutputValue::Html(s) => {
312 outputs.push(document::OutputValue::Html(s.join("")));
313 }
314 OutputValue::Javascript(s) => {
315 outputs.push(document::OutputValue::Javascript(s));
316 }
317 }
318 }
319 }
320 CellOutput::Error { evalue, .. } => {
321 outputs.push(document::OutputValue::Error(evalue));
322 }
323 }
324 }
325
326 CodeOutput { values: outputs }
327 }
328}
329
330pub fn notebook_to_doc(nb: Notebook, accept_draft: bool) -> Result<Option<Document<Ast>>> {
331 let mut writer = BufWriter::new(Vec::new());
332
333 let mut output_map = HashMap::new();
334
335 let mut doc_meta = None;
336
337 for cell in nb.cells {
338 match &cell {
339 Cell::Markdown { common } => {
340 write!(&mut writer, "\n{}\n", common.source)?;
341 }
342 Cell::Code {
343 common, outputs, ..
344 } => {
345 let attr = common
346 .metadata
347 .tags
348 .as_ref()
349 .map(|tags| tags.join(", "))
350 .unwrap_or(String::new());
351 let full = format!("#| tags: {}\n{}\n", attr, common.source);
352
353 write!(&mut writer, "\n```python, cell\n{}```\n", full)?;
354
355 let mut hasher = DefaultHasher::new();
356 full.hash(&mut hasher);
357 output_map.insert(hasher.finish(), CodeOutput::from(outputs.clone()));
358 }
359 Cell::Raw { common } => {
360 if let Ok(meta) = serde_yaml::from_str::<Metadata>(&common.source) {
361 if !accept_draft && meta.draft {
362 return Ok(None);
363 } else {
364 doc_meta = Some(meta);
365 }
366 }
367 }
368 }
369 }
370
371 let source = String::from_utf8(writer.into_inner()?)?;
372 let mut doc = Document::try_from(source.as_str())?;
375 doc.code_outputs = output_map;
376 doc.meta = doc_meta.unwrap_or_default();
377
378 Ok(Some(doc))
379}
380
#[cfg(test)]
mod tests {
    use super::*;

    use crate::ast;
    use crate::ast::{Block, Command, Inline};
    use crate::code_ast::types::{CodeContent, CodeElem};
    use crate::common::Span;
    use std::fs::File;
    use std::io::BufReader;
    use std::path::PathBuf;

    /// Smoke test: the sample notebook under `resources/` deserializes without error.
    #[test]
    fn deserialize() {
        let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        d.push("resources/test_deserialize.ipynb");
        let bf = BufReader::new(File::open(d).expect("Could not open file"));
        let _nb: Notebook = serde_json::from_reader(bf).expect("Deserialization failed");

        println!("Done");
    }

    /// End-to-end check of `notebook_to_doc` on one markdown cell and one code cell.
    #[test]
    fn notebook_to_doc() {
        let nb = Notebook {
            metadata: Default::default(),
            nbformat: 0,
            nbformat_minor: 0,
            cells: vec![
                Cell::Markdown {
                    common: CellCommon {
                        id: "id".to_string(),
                        metadata: Default::default(),
                        source: "# Heading\n#func".to_string(),
                    },
                },
                Cell::Code {
                    common: CellCommon {
                        id: "id".to_string(),
                        metadata: Default::default(),
                        source: "print('x')".to_string(),
                    },
                    execution_count: None,
                    outputs: vec![CellOutput::Data {
                        execution_count: 0,
                        data: vec![OutputValue::Plain(vec!["x".to_string()])],
                        metadata: Default::default(),
                    }],
                },
            ],
        };

        // The spans and the hash below are tied to the exact generated `source` string.
        // The same hash appears as the code block's `CodeContent::hash` and as the
        // `code_outputs` key, presumably how outputs are matched to their block.
        let expected = Document {
            meta: Default::default(),
            content: Ast {
                blocks: vec![
                    Block::Heading {
                        lvl: 1,
                        id: None,
                        classes: vec![],
                        inner: vec![Inline::Text("Heading".into())],
                    },
                    Block::Plain(vec![Inline::Command(Command {
                        function: "func".into(),
                        label: None,
                        parameters: vec![],
                        body: None,
                        span: Span::new(11, 16),
                        global_idx: 0,
                    })]),
                    Block::Plain(vec![Inline::CodeBlock(ast::CodeBlock {
                        label: None,
                        source: CodeContent {
                            blocks: vec![CodeElem::Src("print('x')\n\n".into())],
                            meta: LinkedHashMap::from_iter(
                                [("tags".into(), "".into())].into_iter(),
                            ),
                            hash: 14521985544978239724,
                        },
                        attributes: vec!["python".into(), "cell".into()],
                        display_cell: false,
                        global_idx: 0,
                        span: Span::new(18, 58),
                    })]),
                ],
                source: "\n# Heading\n#func\n\n```python, cell\n#| tags: \nprint('x')\n```\n"
                    .into(),
            },
            code_outputs: HashMap::from([(
                14521985544978239724,
                CodeOutput {
                    values: vec![document::OutputValue::Plain("x".into())],
                },
            )]),
        };
        let parsed = super::notebook_to_doc(nb, true)
            .expect("parsing errors")
            .unwrap();

        assert_eq!(expected, parsed);
    }
}
482}