use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use uuid::Uuid;
/// The `msg_type` carried in every [`MessageHeader`].
///
/// Each variant is (de)serialized as its snake_case name via `rename_all`,
/// e.g. `KernelInfoRequest` <-> `"kernel_info_request"`. There is no catch-all
/// variant, so any other `msg_type` string fails to deserialize.
///
/// Variant order is left as-is: it determines the discriminant values.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum MessageType {
    // Request / reply pairs.
    ExecuteRequest,
    ExecuteReply,
    InspectRequest,
    InspectReply,
    CompleteRequest,
    CompleteReply,
    HistoryRequest,
    HistoryReply,
    IsCompleteRequest,
    IsCompleteReply,
    KernelInfoRequest,
    KernelInfoReply,
    ShutdownRequest,
    ShutdownReply,
    InterruptRequest,
    InterruptReply,
    // Status and output messages (broadcast on the IOPub channel per the
    // Jupyter messaging spec — confirm against the transport layer).
    Status,
    Stream,
    DisplayData,
    ExecuteInput,
    ExecuteResult,
    Error,
    // Input prompt request and its answer.
    InputRequest,
    InputReply,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JupyterMessage {
pub header: MessageHeader,
pub parent_header: Option<MessageHeader>,
pub metadata: HashMap<String, JsonValue>,
pub content: JsonValue,
#[serde(skip_serializing_if = "Vec::is_empty", default)]
pub buffers: Vec<Vec<u8>>,
}
/// Header attached to every Jupyter message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageHeader {
    /// Message identifier (a random UUID v4 for headers built by [`MessageHeader::new`]).
    pub msg_id: String,
    /// Kind of message this header belongs to.
    pub msg_type: MessageType,
    /// Session identifier.
    pub session: String,
    /// Creation timestamp (RFC 3339 for headers built by [`MessageHeader::new`]).
    pub date: String,
    /// Protocol version string.
    pub version: String,
    /// Sender's user name.
    pub username: String,
}

impl MessageHeader {
    /// Creates a header for a new message of `msg_type` in `session`.
    ///
    /// The header gets a fresh UUID v4 id, the current UTC time, protocol
    /// version `"5.3"` and user name `"kernel"`.
    pub fn new(msg_type: MessageType, session: &str) -> Self {
        // Same order as the field list: the id is generated before the clock is read.
        let msg_id = Uuid::new_v4().to_string();
        let date = chrono::Utc::now().to_rfc3339();

        MessageHeader {
            msg_id,
            msg_type,
            session: String::from(session),
            date,
            version: String::from("5.3"),
            username: String::from("kernel"),
        }
    }
}
impl JupyterMessage {
pub fn new(msg_type: MessageType, session: &str, content: JsonValue) -> Self {
Self {
header: MessageHeader::new(msg_type, session),
parent_header: None,
metadata: HashMap::new(),
content,
buffers: Vec::new(),
}
}
pub fn reply(parent: &JupyterMessage, msg_type: MessageType, content: JsonValue) -> Self {
Self {
header: MessageHeader::new(msg_type, &parent.header.session),
parent_header: Some(parent.header.clone()),
metadata: HashMap::new(),
content,
buffers: Vec::new(),
}
}
pub fn to_json(&self) -> serde_json::Result<String> {
serde_json::to_string(self)
}
pub fn from_json(json: &str) -> serde_json::Result<Self> {
serde_json::from_str(json)
}
}
/// Content of an `execute_request` message.
///
/// Serialization writes every field. Deserialization goes through
/// [`ExecuteRequestWire`], so optional fields a frontend omits take the
/// protocol defaults:
/// - `silent`: `false`
/// - `store_history`: `!silent`
/// - `user_expressions`: empty
/// - `allow_stdin`: `false`
/// - `stop_on_error`: `true`
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(from = "ExecuteRequestWire")]
pub struct ExecuteRequest {
    /// Source code to execute.
    pub code: String,
    /// Execute as quietly as possible.
    pub silent: bool,
    /// Record this execution in history.
    pub store_history: bool,
    /// Expressions to evaluate after execution, keyed by name.
    pub user_expressions: HashMap<String, String>,
    /// Whether the kernel may prompt for input.
    pub allow_stdin: bool,
    /// Whether an error should abort the remaining queued requests.
    pub stop_on_error: bool,
}

/// Deserialize-only view of [`ExecuteRequest`] where optional fields may be absent.
#[derive(Deserialize)]
struct ExecuteRequestWire {
    code: String,
    #[serde(default)]
    silent: bool,
    // `None` when absent (or null); resolved against `silent` in the conversion below.
    #[serde(default)]
    store_history: Option<bool>,
    #[serde(default)]
    user_expressions: HashMap<String, String>,
    #[serde(default)]
    allow_stdin: bool,
    #[serde(default = "default_stop_on_error")]
    stop_on_error: bool,
}

/// Protocol default for `stop_on_error` when a request omits it.
fn default_stop_on_error() -> bool {
    true
}

impl From<ExecuteRequestWire> for ExecuteRequest {
    fn from(wire: ExecuteRequestWire) -> Self {
        // A silent request must not be recorded in history unless it asks for it explicitly.
        let store_history = wire.store_history.unwrap_or(!wire.silent);
        ExecuteRequest {
            code: wire.code,
            silent: wire.silent,
            store_history,
            user_expressions: wire.user_expressions,
            allow_stdin: wire.allow_stdin,
            stop_on_error: wire.stop_on_error,
        }
    }
}
/// Content of an `execute_reply` message.
///
/// NOTE(review): the messaging spec attaches `ename`/`evalue`/`traceback` to
/// replies whose status is `error`; this struct has no fields for them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecuteReply {
    /// Outcome of the execution.
    pub status: ExecutionStatus,
    /// Execution counter reported with this reply.
    pub execution_count: u64,
    /// Results of the requested `user_expressions`; empty when absent on input.
    #[serde(default)]
    pub user_expressions: HashMap<String, JsonValue>,
    /// Payload entries; empty when absent on input.
    #[serde(default)]
    pub payload: Vec<JsonValue>,
}

/// Reply status, (de)serialized in lowercase: `"ok"`, `"error"`, `"abort"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ExecutionStatus {
    Ok,
    Error,
    /// Written as `"abort"`. Deserialization also accepts `"aborted"`, the
    /// spelling ipykernel sends for aborted requests.
    #[serde(alias = "aborted")]
    Abort,
}
/// Content of a `kernel_info_reply` message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KernelInfoReply {
    /// Protocol version string reported by the kernel.
    pub protocol_version: String,
    /// Kernel implementation name.
    pub implementation: String,
    /// Kernel implementation version.
    pub implementation_version: String,
    /// Metadata about the language the kernel runs.
    pub language_info: LanguageInfo,
    /// Banner text (presumably shown by frontends on connect — confirm).
    pub banner: String,
    /// Debugger support flag; `false` when absent on input.
    #[serde(default)]
    pub debugger: bool,
    /// Help links; empty when absent on input.
    #[serde(default)]
    pub help_links: Vec<HelpLink>,
}
/// Language metadata embedded in [`KernelInfoReply`].
///
/// NOTE(review): every field is a required `String`. The Jupyter spec treats
/// some of these as optional (and `codemirror_mode` may be an object), so
/// replies produced by other kernels may fail to deserialize — confirm whether
/// that matters for this crate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LanguageInfo {
    pub name: String,
    pub version: String,
    pub mimetype: String,
    pub file_extension: String,
    pub pygments_lexer: String,
    pub codemirror_mode: String,
}
/// A titled link listed in [`KernelInfoReply::help_links`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HelpLink {
    pub text: String,
    pub url: String,
}
/// Content of a `status` message: the kernel's current execution state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Status {
    pub execution_state: ExecutionState,
}

/// Execution state carried by [`Status`].
///
/// Each variant's wire name is spelled out explicitly (all lowercase).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExecutionState {
    #[serde(rename = "starting")]
    Starting,
    #[serde(rename = "idle")]
    Idle,
    #[serde(rename = "busy")]
    Busy,
}
/// Content of a `stream` message: a chunk of text output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Stream {
    /// Stream name (presumably `"stdout"` or `"stderr"` per the spec — not enforced here).
    pub name: String,
    /// Text emitted on the stream.
    pub text: String,
}
/// Content of an `error` message describing a failed execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorContent {
    /// Error (exception) name.
    pub ename: String,
    /// Error value / message.
    pub evalue: String,
    /// Traceback, one string per entry.
    pub traceback: Vec<String>,
}
/// Content of an `execute_result` message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecuteResult {
    /// Execution counter this result belongs to.
    pub execution_count: u64,
    /// Result representations (presumably keyed by MIME type, e.g. `"text/plain"` — confirm).
    pub data: HashMap<String, JsonValue>,
    /// Extra metadata about `data`; empty when absent on input.
    #[serde(default)]
    pub metadata: HashMap<String, JsonValue>,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_creation() {
        let content = serde_json::json!({"code": "x = 1 + 2"});
        let msg = JupyterMessage::new(MessageType::ExecuteRequest, "test-session", content.clone());
        assert_eq!(msg.header.msg_type, MessageType::ExecuteRequest);
        assert_eq!(msg.header.session, "test-session");
        assert_eq!(msg.header.version, "5.3");
        assert!(!msg.header.msg_id.is_empty());
        assert!(msg.parent_header.is_none());
        assert!(msg.metadata.is_empty());
        assert!(msg.buffers.is_empty());
        assert_eq!(msg.content, content);
    }

    #[test]
    fn test_reply_message() {
        let request_content = serde_json::json!({"code": "x = 1"});
        let request = JupyterMessage::new(MessageType::ExecuteRequest, "test", request_content);
        let reply_content = serde_json::json!({"status": "ok"});
        let reply = JupyterMessage::reply(&request, MessageType::ExecuteReply, reply_content);
        assert_eq!(reply.header.msg_type, MessageType::ExecuteReply);
        assert_eq!(reply.header.session, "test");
        // A reply is a new message, so it must not reuse the request's id.
        assert_ne!(reply.header.msg_id, request.header.msg_id);
        let parent = reply
            .parent_header
            .expect("a reply must carry its parent's header");
        assert_eq!(parent.msg_id, request.header.msg_id);
        assert_eq!(parent.msg_type, MessageType::ExecuteRequest);
    }

    #[test]
    fn test_execute_request_serialization() {
        let mut user_expressions = HashMap::new();
        user_expressions.insert("answer".to_string(), "6 * 7".to_string());
        let execute_req = ExecuteRequest {
            code: "disp('hello')".to_string(),
            silent: false,
            store_history: true,
            user_expressions,
            allow_stdin: false,
            stop_on_error: true,
        };
        let json = serde_json::to_string(&execute_req).unwrap();
        let parsed: ExecuteRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.code, execute_req.code);
        assert_eq!(parsed.silent, execute_req.silent);
        assert_eq!(parsed.store_history, execute_req.store_history);
        assert_eq!(parsed.user_expressions, execute_req.user_expressions);
        assert_eq!(parsed.allow_stdin, execute_req.allow_stdin);
        assert_eq!(parsed.stop_on_error, execute_req.stop_on_error);
    }

    #[test]
    fn test_message_json_roundtrip() {
        let content = serde_json::json!({
            "code": "x = magic(3)",
            "silent": false
        });
        let original = JupyterMessage::new(MessageType::ExecuteRequest, "test", content);
        let json = original.to_json().unwrap();
        let parsed = JupyterMessage::from_json(&json).unwrap();
        assert_eq!(parsed.header.msg_id, original.header.msg_id);
        assert_eq!(parsed.header.msg_type, original.header.msg_type);
        assert_eq!(parsed.header.session, original.header.session);
        assert_eq!(parsed.header.date, original.header.date);
        assert_eq!(parsed.header.version, original.header.version);
        assert_eq!(parsed.header.username, original.header.username);
        assert!(parsed.parent_header.is_none());
        assert_eq!(parsed.metadata, original.metadata);
        assert_eq!(parsed.content, original.content);
    }

    #[test]
    fn test_reply_json_roundtrip_keeps_parent() {
        let request = JupyterMessage::new(MessageType::KernelInfoRequest, "s", serde_json::json!({}));
        let reply = JupyterMessage::reply(&request, MessageType::KernelInfoReply, serde_json::json!({}));
        let parsed = JupyterMessage::from_json(&reply.to_json().unwrap()).unwrap();
        let parent = parsed.parent_header.expect("parent header must survive a round-trip");
        assert_eq!(parent.msg_id, request.header.msg_id);
        assert_eq!(parent.msg_type, MessageType::KernelInfoRequest);
    }

    #[test]
    fn test_status_message() {
        let status = Status {
            execution_state: ExecutionState::Busy,
        };
        let content = serde_json::to_value(&status).unwrap();
        assert_eq!(content, serde_json::json!({"execution_state": "busy"}));
        let msg = JupyterMessage::new(MessageType::Status, "test", content);
        assert_eq!(msg.header.msg_type, MessageType::Status);
    }

    #[test]
    fn test_execution_status_serialization() {
        assert_eq!(serde_json::to_value(ExecutionStatus::Ok).unwrap(), serde_json::json!("ok"));
        assert_eq!(serde_json::to_value(ExecutionStatus::Error).unwrap(), serde_json::json!("error"));
        assert_eq!(serde_json::to_value(ExecutionStatus::Abort).unwrap(), serde_json::json!("abort"));
    }
}