use std::collections::HashMap;
use std::sync::Arc;
use async_openai::types::chat::{
ChatCompletionRequestMessage, ChatCompletionTools, CreateChatCompletionResponse,
CreateChatCompletionStreamResponse,
};
use serde_json::{json, Value};
use crate::detail::core_interop::CoreInterop;
use crate::error::{FoundryLocalError, Result};
use crate::types::{ChatResponseFormat, ChatToolChoice};
use super::json_stream::JsonStream;
#[derive(Debug, Clone, Default)]
pub struct ChatClientSettings {
frequency_penalty: Option<f64>,
max_tokens: Option<u32>,
n: Option<u32>,
temperature: Option<f64>,
presence_penalty: Option<f64>,
top_p: Option<f64>,
top_k: Option<u32>,
random_seed: Option<u64>,
response_format: Option<ChatResponseFormat>,
tool_choice: Option<ChatToolChoice>,
}
impl ChatClientSettings {
fn serialize(&self) -> Value {
let mut map = serde_json::Map::new();
if let Some(v) = self.frequency_penalty {
map.insert("frequency_penalty".into(), json!(v));
}
if let Some(v) = self.max_tokens {
map.insert("max_tokens".into(), json!(v));
}
if let Some(v) = self.n {
map.insert("n".into(), json!(v));
}
if let Some(v) = self.presence_penalty {
map.insert("presence_penalty".into(), json!(v));
}
if let Some(v) = self.temperature {
map.insert("temperature".into(), json!(v));
}
if let Some(v) = self.top_p {
map.insert("top_p".into(), json!(v));
}
if let Some(ref rf) = self.response_format {
let mut rf_map = serde_json::Map::new();
match rf {
ChatResponseFormat::Text => {
rf_map.insert("type".into(), json!("text"));
}
ChatResponseFormat::JsonObject => {
rf_map.insert("type".into(), json!("json_object"));
}
ChatResponseFormat::JsonSchema(schema) => {
rf_map.insert("type".into(), json!("json_schema"));
rf_map.insert("jsonSchema".into(), json!(schema));
}
ChatResponseFormat::LarkGrammar(grammar) => {
rf_map.insert("type".into(), json!("lark_grammar"));
rf_map.insert("larkGrammar".into(), json!(grammar));
}
}
map.insert("response_format".into(), Value::Object(rf_map));
}
if let Some(ref tc) = self.tool_choice {
let mut tc_map = serde_json::Map::new();
match tc {
ChatToolChoice::None => {
tc_map.insert("type".into(), json!("none"));
}
ChatToolChoice::Auto => {
tc_map.insert("type".into(), json!("auto"));
}
ChatToolChoice::Required => {
tc_map.insert("type".into(), json!("required"));
}
ChatToolChoice::Function(name) => {
tc_map.insert("type".into(), json!("function"));
tc_map.insert("name".into(), json!(name));
}
}
map.insert("tool_choice".into(), Value::Object(tc_map));
}
let mut metadata: HashMap<String, String> = HashMap::new();
if let Some(k) = self.top_k {
metadata.insert("top_k".into(), k.to_string());
}
if let Some(s) = self.random_seed {
metadata.insert("random_seed".into(), s.to_string());
}
if !metadata.is_empty() {
map.insert("metadata".into(), json!(metadata));
}
Value::Object(map)
}
}
pub type ChatCompletionStream = JsonStream<CreateChatCompletionStreamResponse>;
pub struct ChatClient {
model_id: String,
core: Arc<CoreInterop>,
settings: ChatClientSettings,
}
impl ChatClient {
pub(crate) fn new(model_id: &str, core: Arc<CoreInterop>) -> Self {
Self {
model_id: model_id.to_owned(),
core,
settings: ChatClientSettings::default(),
}
}
pub fn frequency_penalty(mut self, v: f64) -> Self {
self.settings.frequency_penalty = Some(v);
self
}
pub fn max_tokens(mut self, v: u32) -> Self {
self.settings.max_tokens = Some(v);
self
}
pub fn n(mut self, v: u32) -> Self {
self.settings.n = Some(v);
self
}
pub fn temperature(mut self, v: f64) -> Self {
self.settings.temperature = Some(v);
self
}
pub fn presence_penalty(mut self, v: f64) -> Self {
self.settings.presence_penalty = Some(v);
self
}
pub fn top_p(mut self, v: f64) -> Self {
self.settings.top_p = Some(v);
self
}
pub fn top_k(mut self, v: u32) -> Self {
self.settings.top_k = Some(v);
self
}
pub fn random_seed(mut self, v: u64) -> Self {
self.settings.random_seed = Some(v);
self
}
pub fn response_format(mut self, v: ChatResponseFormat) -> Self {
self.settings.response_format = Some(v);
self
}
pub fn tool_choice(mut self, v: ChatToolChoice) -> Self {
self.settings.tool_choice = Some(v);
self
}
pub async fn complete_chat(
&self,
messages: &[ChatCompletionRequestMessage],
tools: Option<&[ChatCompletionTools]>,
) -> Result<CreateChatCompletionResponse> {
if messages.is_empty() {
return Err(FoundryLocalError::Validation {
reason: "messages must be a non-empty array".into(),
});
}
let request = self.build_request(messages, tools, false)?;
let params = json!({
"Params": {
"OpenAICreateRequest": serde_json::to_string(&request)?
}
});
let raw = self
.core
.execute_command_async("chat_completions".into(), Some(params))
.await?;
let parsed: CreateChatCompletionResponse = serde_json::from_str(&raw)?;
Ok(parsed)
}
pub async fn complete_streaming_chat(
&self,
messages: &[ChatCompletionRequestMessage],
tools: Option<&[ChatCompletionTools]>,
) -> Result<ChatCompletionStream> {
if messages.is_empty() {
return Err(FoundryLocalError::Validation {
reason: "messages must be a non-empty array".into(),
});
}
let request = self.build_request(messages, tools, true)?;
let params = json!({
"Params": {
"OpenAICreateRequest": serde_json::to_string(&request)?
}
});
let rx = self
.core
.execute_command_streaming_channel("chat_completions".into(), Some(params))
.await?;
Ok(ChatCompletionStream::new(rx))
}
fn build_request(
&self,
messages: &[ChatCompletionRequestMessage],
tools: Option<&[ChatCompletionTools]>,
stream: bool,
) -> Result<Value> {
let settings_value = self.settings.serialize();
let mut map = match settings_value {
Value::Object(m) => m,
_ => serde_json::Map::new(),
};
map.insert("model".into(), json!(self.model_id));
map.insert("messages".into(), serde_json::to_value(messages)?);
if stream {
map.insert("stream".into(), json!(true));
}
if let Some(t) = tools {
map.insert("tools".into(), serde_json::to_value(t)?);
}
Ok(Value::Object(map))
}
}