use std::future::Future;
use futures::StreamExt;
use serde::{Deserialize, Serialize};
use yaaral::prelude::*;
use crate::prelude::*;
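/// Thin client for an OpenAI-compatible completion server. The wire formats
/// below (`oaicompat_msg_diffs`, the `Final` frame, the GBNF `grammar`
/// fields) match what llama.cpp's `llama-server` emits, so that is
/// presumably the intended backend.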
#[derive(Debug)]
pub struct SimpleApi<RT>
where
RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
{
runtime: RT,
baseuri: String,
port: u16,
model: Option<String>,
}
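/// A single chat message in the OpenAI wire format.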
#[derive(Debug, Serialize, Deserialize)]
struct Message
{
role: String,
content: String,
}
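/// Request body for `POST /v1/chat/completions`.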
#[derive(Debug, Serialize)]
struct ChatRequestBody
{
#[serde(skip_serializing_if = "Option::is_none")]
model: Option<String>,
messages: Vec<Message>,
temperature: f32,
max_tokens: u32,
stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    grammar: Option<String>,
}
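/// Request body for `POST /v1/completions` (plain prompt, no chat roles).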
#[derive(Debug, Serialize)]
struct GenerationRequestBody
{
#[serde(skip_serializing_if = "Option::is_none")]
model: Option<String>,
prompt: String,
temperature: f32,
max_tokens: u32,
stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    grammar: Option<String>,
}
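/// One server-sent event from the streaming chat endpoint.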
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct ChatChunk
{
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<Choice>,
#[serde(default)]
    pub timings: Option<Timings>,
}
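/// One choice inside a chat chunk; streamed text arrives in `delta`.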
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Choice
{
pub index: u32,
#[serde(default)]
pub finish_reason: Option<String>,
#[serde(default)]
pub delta: Delta,
}
#[derive(Debug, Deserialize, Default)]
#[allow(dead_code)]
struct Delta
{
#[serde(default)]
pub role: Option<String>,
#[serde(default)]
pub content: Option<String>,
}
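/// Per-request timing statistics (a llama.cpp extension to the OpenAI
/// format, hence every field is optional).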
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Timings
{
pub prompt_n: Option<i32>,
pub prompt_ms: Option<f64>,
pub prompt_per_token_ms: Option<f64>,
pub prompt_per_second: Option<f64>,
pub predicted_n: Option<i32>,
pub predicted_ms: Option<f64>,
pub predicted_per_token_ms: Option<f64>,
pub predicted_per_second: Option<f64>,
}
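/// One frame of the raw completion stream: an incremental delta or the
/// final summary frame.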
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
#[allow(clippy::large_enum_variant)]
enum StreamFrame
{
    // `serde(untagged)` tries variants in declaration order. `Final` must be
    // tried first: a final frame also carries `content` and `stop`, so it
    // would otherwise always match the more permissive `Delta` shape and the
    // stream would never report itself finished.
    Final(Final),
    Delta(GenerationDelta),
}
#[derive(Debug, Deserialize, Serialize)]
struct GenerationDelta
{
pub content: String,
pub stop: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub oaicompat_msg_diffs: Option<Vec<OaiDelta>>,
}
#[derive(Debug, Deserialize, Serialize)]
struct OaiDelta
{
pub content_delta: String,
}
#[derive(Debug, Deserialize, Serialize)]
struct Final
{
pub content: String,
pub generated_text: String,
pub stop: bool,
pub model: String,
pub tokens_predicted: u64,
pub tokens_evaluated: u64,
pub generation_settings: serde_json::Value,
pub prompt: String,
pub truncated: bool,
pub stopped_eos: bool,
pub stopped_word: bool,
pub stopped_limit: bool,
pub tokens_cached: u64,
pub timings: serde_json::Value,
}
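/// Error payload carried on `error:` SSE lines.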
#[derive(Debug, Deserialize, Serialize)]
pub(crate) struct ApiError
{
pub code: Option<u32>,
pub message: String,
#[serde(rename = "type")]
pub typ: String,
}
impl<RT> SimpleApi<RT>
where
RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
{
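    /// `baseuri` is expected to include the scheme (e.g. `"http://localhost"`);
    /// the port is appended separately when request URLs are built.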
pub fn new(
runtime: RT,
baseuri: impl Into<String>,
port: u16,
model: Option<String>,
) -> Result<Self>
{
Ok(Self {
baseuri: baseuri.into(),
port,
model,
runtime,
})
}
}
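/// Common view over both streaming wire formats, so `response_to_stream`
/// can extract text fragments and detect termination generically.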
trait Data
{
fn content(&self) -> Option<&String>;
fn is_finished(&self) -> bool;
}
impl Data for ChatChunk
{
fn content(&self) -> Option<&String>
{
self.choices.first().and_then(|c| c.delta.content.as_ref())
}
    fn is_finished(&self) -> bool
    {
        self.choices
            .first()
            .and_then(|c| c.finish_reason.as_deref())
            == Some("stop")
    }
}
impl Data for StreamFrame
{
fn content(&self) -> Option<&String>
{
match self
{
StreamFrame::Delta(delta) => Some(&delta.content),
StreamFrame::Final(_) => None,
}
}
fn is_finished(&self) -> bool
{
match self
{
Self::Delta(_) => false,
Self::Final(_) => true,
}
}
}
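/// Parses a server-sent-event response into a stream of text fragments.
/// Each network chunk may carry several `data:`/`error:` lines; every line
/// is decoded independently and the results are flattened into one stream.
/// Note: this assumes an SSE line never straddles a chunk boundary, which
/// line-buffered servers honour in practice but HTTP does not guarantee.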
fn response_to_stream<D: Data + for<'de> Deserialize<'de>>(
response: impl yaaral::http::Response,
) -> Result<StringStream>
{
let stream = response.into_stream().map(|chunk_result| {
let mut results = vec![];
match chunk_result
{
Ok(chunk) =>
{
let chunk_str = String::from_utf8_lossy(&chunk);
for line in chunk_str.lines()
{
                    let line = line.trim();
                    if let Some(payload) = line.strip_prefix("data:")
                    {
                        let json_str = payload.trim_start();
                        // OpenAI-compatible servers terminate the stream with
                        // a literal `[DONE]` sentinel that is not JSON; skip
                        // it instead of reporting a parse error.
                        if json_str == "[DONE]"
                        {
                            continue;
                        }
                        match serde_json::from_str::<D>(json_str)
                        {
                            Ok(chunk) =>
                            {
                                if let Some(content) = chunk.content()
                                {
                                    results.push(Ok(content.to_owned()));
                                }
                                if chunk.is_finished()
                                {
                                    break;
                                }
                            }
                            Err(e) => results.push(Err(e.into())),
                        }
                    }
                    else if let Some(payload) = line.strip_prefix("error:")
                    {
                        match serde_json::from_str::<ApiError>(payload.trim_start())
                        {
                            Ok(err) => results.push(Err(Error::SimpleApiError {
                                code: err.code.unwrap_or_default(),
                                message: err.message,
                                error_type: err.typ,
                            })),
                            Err(_) => log::error!("Unparseable error line: {}.", line),
                        }
                    }
else if !line.is_empty()
{
log::error!("Unhandled line: {}.", line);
}
}
}
Err(e) =>
{
results.push(Err(Error::HttpError(format!("{:?}", e))));
}
}
futures::stream::iter(results)
});
let flat_stream = stream.flatten().boxed();
Ok(pin_stream(flat_stream))
}
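/// Maps an output format to a GBNF grammar bundled with the crate; `None`
/// leaves generation unconstrained.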
fn grammar_for(format: crate::Format) -> Option<String>
{
match format
{
crate::Format::Json => Some(include_str!("../data/grammar/json.gbnf").to_string()),
crate::Format::Text => None,
}
}
impl<RT> LargeLanguageModel for SimpleApi<RT>
where
RT: yaaral::TaskInterface + yaaral::http::HttpClientInterface,
{
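    /// Streams a chat completion from `/v1/chat/completions`. The request is
    /// built before the returned future, so only owned data (and a cloned
    /// runtime handle) is moved into it.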
fn chat_stream(
&self,
prompt: ChatPrompt,
) -> Result<impl Future<Output = Result<StringStream>> + Send>
{
let url = format!("{}:{}/v1/chat/completions", self.baseuri, self.port);
let messages = prompt
.messages
.into_iter()
.map(|m| Message {
role: match m.role
{
Role::User => "user".to_string(),
Role::System => "system".to_string(),
Role::Assistant => "assistant".to_string(),
Role::Custom(custom) => custom,
},
content: m.content,
})
.collect();
let request_body = ChatRequestBody {
model: self.model.to_owned(),
messages,
temperature: 0.7,
max_tokens: 2560,
stream: true,
grammar: grammar_for(prompt.format),
};
let rt = self.runtime.clone();
let yreq = yaaral::http::Request::from_uri(url).json(&request_body)?;
Ok(async move {
let response = rt.wpost(yreq).await;
if !response.status().is_success()
{
return Err(Error::HttpError(format!(
"Error code {}",
response.status()
)));
}
response_to_stream::<ChatChunk>(response)
})
}
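    /// Streams a raw completion. Prompts with only a user part go straight
    /// to `/v1/completions`; prompts with system or assistant parts fall
    /// back to the chat endpoint via `crate::generate_with_chat`.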
fn generate_stream(
&self,
prompt: GenerationPrompt,
) -> Result<impl Future<Output = Result<StringStream>> + Send>
{
let rt = self.runtime.clone();
Ok(async move {
if prompt.system.is_none() && prompt.assistant.is_none()
{
let url = format!("{}:{}/v1/completions", self.baseuri, self.port);
let request_body = GenerationRequestBody {
model: self.model.to_owned(),
prompt: prompt.user,
temperature: 0.7,
max_tokens: 2560,
stream: true,
grammar: grammar_for(prompt.format),
};
let yreq = yaaral::http::Request::from_uri(url).json(&request_body)?;
let response = rt.wpost(yreq).await;
if !response.status().is_success()
{
return Err(Error::HttpError(format!(
"Error code {}",
response.status()
)));
}
response_to_stream::<StreamFrame>(response)
}
else
{
crate::generate_with_chat(self, prompt)?.await
}
})
}
}