pub(crate) mod request;
pub(crate) mod response;
use serde::{Deserialize, Serialize};
/// Optional per-request settings in their serializable (wire) form.
///
/// Every field is an `Option` annotated with
/// `skip_serializing_if = "Option::is_none"`, so a field that is `None` is
/// left out of the serialized output entirely. The derived `Default` produces
/// a value with every field unset.
// NOTE(review): the field names suggest text-generation sampling controls;
// confirm exact semantics and accepted ranges against the receiving API.
#[derive(Debug, Default, Serialize)]
pub(crate) struct WireOptions {
/// Omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f64>,
/// Omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// Omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub greedy: Option<bool>,
/// Omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_k: Option<u32>,
/// Omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f64>,
/// Omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<u64>,
}
/// Serializable session settings.
///
/// Both fields are optional and are left out of the serialized output when
/// they are `None`.
// NOTE(review): presumably sent when a session is created; the call site is
// not visible in this file — confirm.
#[derive(Debug, Serialize)]
pub(crate) struct SessionConfig {
/// Optional instruction text; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub instructions: Option<String>,
/// Optional `lora` string; omitted from the output when `None`.
// NOTE(review): looks like an adapter name or path — confirm the expected format.
#[serde(skip_serializing_if = "Option::is_none")]
pub lora: Option<String>,
}
/// Serializable request body holding one message and optional settings.
#[derive(Debug, Serialize)]
pub(crate) struct TurnRequest {
/// Message text; always present in the serialized output.
pub message: String,
/// Optional [`WireOptions`]; the `options` key is omitted when this is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub options: Option<WireOptions>,
}
/// Deserialized reply carrying a single `session` number.
// NOTE(review): presumably the identifier of a newly created session — confirm
// against the code that issues the request.
#[derive(Debug, Deserialize)]
pub(crate) struct SessionCreated {
/// Value of the `session` key in the reply.
pub session: u64,
}
/// Deserialized reply with two string fields, `text` and `finish`.
///
/// Neither field has a serde default, so both keys must be present for
/// deserialization to succeed.
#[derive(Debug, Deserialize)]
pub(crate) struct CompleteReply {
/// Value of the `text` key.
pub text: String,
/// Value of the `finish` key, kept as a free-form string.
// NOTE(review): presumably a finish/stop reason such as "stop" — confirm the
// set of values the server can emit.
pub finish: String,
}
/// Deserialized error payload: a `kind` string and a `message` string.
///
/// Used as the `error` field of [`ErrorReply`] and of the
/// `WireStreamEvent::Error` variant.
#[derive(Debug, Deserialize)]
pub(crate) struct ErrorBody {
/// Value of the `kind` key, kept as a free-form string.
pub kind: String,
/// Value of the `message` key.
pub message: String,
}
/// Deserialized envelope whose single `error` key holds an [`ErrorBody`].
#[derive(Debug, Deserialize)]
pub(crate) struct ErrorReply {
/// The nested error payload.
pub error: ErrorBody,
}
/// One deserialized streaming event; compiled only with the `stream` feature.
///
/// The enum is internally tagged: the `type` key of the incoming object picks
/// the variant, and variant names are matched in snake_case (`delta`, `done`,
/// `error`). The remaining keys of the same object fill the variant's fields.
#[cfg(feature = "stream")]
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub(crate) enum WireStreamEvent {
/// Selected by `"type": "delta"`; carries a `text` string.
// NOTE(review): presumably an incremental chunk of output — confirm.
Delta {
text: String,
},
/// Selected by `"type": "done"`; carries `text` and `finish` strings.
// NOTE(review): presumably the final event of a stream — confirm.
Done {
text: String,
finish: String,
},
/// Selected by `"type": "error"`; carries an [`ErrorBody`] under `error`.
Error {
error: ErrorBody,
},
}
#[cfg(all(test, feature = "stream"))]
mod tests {
    use super::*;

    /// Parses `raw` as a [`WireStreamEvent`], panicking if it does not decode.
    fn decode(raw: &str) -> WireStreamEvent {
        serde_json::from_str(raw).unwrap()
    }

    /// Each `type` tag must select the matching variant and populate its fields.
    #[test]
    fn parses_stream_events() {
        // `delta` carries only the text.
        let delta_ok = match decode(r#"{"type":"delta","text":"hi"}"#) {
            WireStreamEvent::Delta { text } => text == "hi",
            _ => false,
        };
        assert!(delta_ok);

        // `done` carries the text plus the finish string; only `finish` is checked.
        let done_ok = match decode(r#"{"type":"done","text":"hi there","finish":"stop"}"#) {
            WireStreamEvent::Done { finish, .. } => finish == "stop",
            _ => false,
        };
        assert!(done_ok);

        // `error` nests an `ErrorBody` under the `error` key.
        let error_ok =
            match decode(r#"{"type":"error","error":{"kind":"generation","message":"x"}}"#) {
                WireStreamEvent::Error { error } => error.kind == "generation",
                _ => false,
            };
        assert!(error_ok);
    }
}