use std::hash::{DefaultHasher, Hash, Hasher};
use chat_core::error::{ChatError, ChatFailure};
use chat_core::types::messages::Messages;
use chat_core::types::messages::content::RoleEnum;
use chat_core::types::messages::parts::PartEnum;
use chat_core::types::options::ChatOptions;
use crate::client::{Config, Sampling};
use super::{SessionConfig, TurnRequest, WireOptions};
/// One non-system message of the conversation, flattened to plain text by [`prepare`].
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct ConvoEntry {
/// Speaker label; [`prepare`] only ever emits `"user"` or `"assistant"`.
pub role: &'static str,
/// The message's text parts joined with `'\n'`.
pub text: String,
}
/// Splits `messages` into optional session instructions and a plain-text conversation.
///
/// System messages are concatenated (newline-separated, in order) into the instructions.
/// User and model messages become [`ConvoEntry`] values with roles `"user"` / `"assistant"`,
/// preserving their original order.
///
/// # Errors
/// Returns a provider error when:
/// - tool declarations are present;
/// - a structured-output schema is requested;
/// - any message contains a non-text part;
/// - the conversation contains no user message.
pub(crate) fn prepare(
    messages: &Messages,
    structured_output: Option<&schemars::Schema>,
    tools_present: bool,
) -> Result<(Option<String>, Vec<ConvoEntry>), ChatFailure> {
    if tools_present {
        return Err(unsupported("tool declarations"));
    }
    if structured_output.is_some() {
        return Err(unsupported("structured outputs"));
    }
    let mut instructions = String::new();
    let mut convo = Vec::new();
    for content in &messages.0 {
        let text = flatten_text_only(&content.parts.0)?;
        match content.role {
            RoleEnum::System => {
                if !instructions.is_empty() {
                    instructions.push('\n');
                }
                instructions.push_str(&text);
            }
            RoleEnum::User => convo.push(ConvoEntry { role: "user", text }),
            RoleEnum::Model => convo.push(ConvoEntry {
                role: "assistant",
                text,
            }),
        }
    }
    // Require an actual user turn, as the error message promises. A transcript made only of
    // model messages would otherwise reach the model with the assistant's own text as the prompt.
    if !convo.iter().any(|entry| entry.role == "user") {
        return Err(ChatFailure::from_err(ChatError::Provider(
            "chat-applefm needs at least one user message".into(),
        )));
    }
    Ok(((!instructions.is_empty()).then_some(instructions), convo))
}
/// How the next turn should be sent to the model.
///
/// NOTE(review): the code that chooses a plan is not in this chunk. Presumably `Reuse` sends
/// only the newest message on an existing session, and `Rebuild` recreates the session from
/// the full transcript (see `render_full`). Confirm against the caller.
#[derive(Debug, PartialEq)]
pub(crate) enum TurnPlan {
Reuse,
Rebuild,
}
/// Fingerprints the system instructions, including whether they are present at all.
///
/// The `Option` discriminant is fed to the hasher, so `None` and `Some("")` are hashed from
/// different inputs. `DefaultHasher`'s algorithm is not guaranteed to be stable across Rust
/// releases, so these values should only be compared within one build and never persisted.
pub(crate) fn hash_instructions(instructions: Option<&str>) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(&instructions, &mut state);
    state.finish()
}
/// Fingerprints an ordered transcript.
///
/// Each entry's role, then its text, is streamed into a single hasher, so reordering or
/// editing any entry changes the input to the hash. The same `DefaultHasher` stability
/// caveat applies: compare only within one build.
pub(crate) fn hash_convo(entries: &[ConvoEntry]) -> u64 {
    let state = entries
        .iter()
        .fold(DefaultHasher::new(), |mut state, entry| {
            entry.role.hash(&mut state);
            entry.text.hash(&mut state);
            state
        });
    state.finish()
}
/// Renders a whole transcript as one prompt string.
///
/// A single-entry transcript is returned verbatim, with no speaker label. Otherwise each entry
/// becomes `"<Label>: <text>"` and entries are separated by a blank line (`"\n\n"`). Entries
/// with role `"assistant"` get the label `Assistant`; every other role is labelled `User`.
/// An empty transcript renders as the empty string.
pub(crate) fn render_full(convo: &[ConvoEntry]) -> String {
    match convo {
        [single] => single.text.clone(),
        _ => {
            let mut rendered = String::new();
            for (idx, entry) in convo.iter().enumerate() {
                if idx > 0 {
                    rendered.push_str("\n\n");
                }
                let label = match entry.role {
                    "assistant" => "Assistant",
                    _ => "User",
                };
                rendered.push_str(label);
                rendered.push_str(": ");
                rendered.push_str(&entry.text);
            }
            rendered
        }
    }
}
/// Serializes the payload used to open a model session.
///
/// The payload carries the optional system `instructions` and, when configured, the LoRA
/// adapter path from `config`. The path is converted with `to_string_lossy`, so any non-UTF-8
/// bytes are replaced rather than rejected.
pub(crate) fn session_config_json(
    instructions: Option<&str>,
    config: &Config,
) -> Result<String, ChatFailure> {
    let lora = config
        .lora
        .as_ref()
        .map(|path| path.to_string_lossy().into_owned());
    let instructions = instructions.map(String::from);
    to_json(&SessionConfig { instructions, lora })
}
/// Serializes one turn request: the `message` text plus optional generation options.
pub(crate) fn turn_request_json(
    message: String,
    options: Option<WireOptions>,
) -> Result<String, ChatFailure> {
    let request = TurnRequest { message, options };
    to_json(&request)
}
/// Serializes `value` to compact JSON.
///
/// A serde failure is wrapped in a [`ChatFailure`] whose message starts with
/// `"request serialization: "`.
fn to_json<T: serde::Serialize>(value: &T) -> Result<String, ChatFailure> {
    match serde_json::to_string(value) {
        Ok(json) => Ok(json),
        Err(err) => Err(ChatFailure::from_err(ChatError::Other(format!(
            "request serialization: {err}"
        )))),
    }
}
/// Builds the wire-level generation options for one request.
///
/// Builder defaults from `config` (temperature, max tokens, sampling mode) are applied first.
/// Per-call `opts` then override them as follows:
/// - `temperature` and `max_tokens` override individually.
/// - Sampling settings are treated as one family. If the call supplies any of
///   `metadata["greedy"]` (bool), `metadata["top_k"]` (unsigned int), `top_p`, or
///   `metadata["seed"]` (unsigned int), all four wire fields are replaced by the call's
///   values, and unset ones become `None`. The builder's sampling mode is discarded, never
///   mixed with the call's. Metadata values of the wrong JSON type are ignored.
///
/// A `top_k` above `u32::MAX` is clamped to `u32::MAX` instead of wrapping.
///
/// Returns `None` when no option ended up set, so the caller can omit options entirely.
pub(crate) fn merge_options(config: &Config, opts: Option<&ChatOptions>) -> Option<WireOptions> {
    let mut wire = WireOptions {
        temperature: config.temperature,
        max_tokens: config.max_tokens,
        ..Default::default()
    };
    match config.sampling {
        Some(Sampling::Greedy) => wire.greedy = Some(true),
        Some(Sampling::TopK { k, seed }) => (wire.top_k, wire.seed) = (Some(k), seed),
        Some(Sampling::TopP { p, seed }) => (wire.top_p, wire.seed) = (Some(p), seed),
        None => {}
    }
    if let Some(opts) = opts {
        if let Some(t) = opts.temperature {
            wire.temperature = Some(f64::from(t));
        }
        if let Some(m) = opts.max_tokens {
            wire.max_tokens = Some(m);
        }
        let greedy = opts.metadata.get("greedy").and_then(|v| v.as_bool());
        // `top_k` comes from caller-supplied metadata. A bare `as u32` would silently wrap
        // (2^32 -> 0). Clamp first; the cast is then lossless.
        let top_k = opts
            .metadata
            .get("top_k")
            .and_then(|v| v.as_u64())
            .map(|k| k.min(u64::from(u32::MAX)) as u32);
        let seed = opts.metadata.get("seed").and_then(|v| v.as_u64());
        let top_p = opts.top_p.map(f64::from);
        // Any call-level sampling knob replaces the whole builder sampling family.
        if greedy.is_some() || top_k.is_some() || top_p.is_some() || seed.is_some() {
            (wire.greedy, wire.top_k, wire.top_p, wire.seed) = (greedy, top_k, top_p, seed);
        }
    }
    // Exhaustive pattern (no `..`) so adding a WireOptions field forces this check to be updated.
    let is_empty = matches!(
        wire,
        WireOptions {
            temperature: None,
            max_tokens: None,
            greedy: None,
            top_k: None,
            top_p: None,
            seed: None,
        }
    );
    (!is_empty).then_some(wire)
}
/// Joins the text parts of one message with `'\n'`, rejecting any non-text part.
///
/// The first non-text part aborts with an `unsupported` provider error naming the part kind.
/// The separator is only inserted when the accumulated text is already non-empty, so leading
/// empty text parts leave no trace in the result.
fn flatten_text_only(parts: &[PartEnum]) -> Result<String, ChatFailure> {
    parts.iter().try_fold(String::new(), |mut joined, part| {
        let text = match part {
            PartEnum::Text(t) => t.as_str(),
            PartEnum::File(f) => {
                return Err(unsupported(&format!("file parts (mimetype {})", f.mime)));
            }
            PartEnum::Tool(_) => return Err(unsupported("tool parts")),
            PartEnum::Structured(_) => return Err(unsupported("structured parts in input")),
            PartEnum::Reasoning(_) => return Err(unsupported("reasoning parts in input")),
            PartEnum::Embeddings(_) => return Err(unsupported("embedding parts in input")),
        };
        if !joined.is_empty() {
            joined.push('\n');
        }
        joined.push_str(text);
        Ok(joined)
    })
}
/// Builds the provider error returned for any input feature this backend cannot handle yet.
fn unsupported(what: &str) -> ChatFailure {
    let message = format!("chat-applefm does not yet support {what}");
    ChatFailure::from_err(ChatError::Provider(message))
}
// Unit tests for request preparation, transcript rendering, option merging and session config.
#[cfg(test)]
mod tests {
use super::*;
use chat_core::parts;
use chat_core::types::messages::content;
// Shorthand constructor for expected `ConvoEntry` values.
fn entry(role: &'static str, text: &str) -> ConvoEntry {
ConvoEntry {
role,
text: text.to_owned(),
}
}
// System messages go to the instructions, not the conversation.
#[test]
fn folds_system_into_instructions() {
let mut messages = Messages::default();
messages.push(content::from_system(parts!["Talk like a pirate."]));
messages.push(content::from_user(parts!["hello"]));
let (instructions, convo) = prepare(&messages, None, false).unwrap();
assert_eq!(instructions.as_deref(), Some("Talk like a pirate."));
assert_eq!(convo, vec![entry("user", "hello")]);
}
// Tool declarations and structured-output schemas are both rejected up front.
#[test]
fn rejects_tools_and_structured() {
let mut messages = Messages::default();
messages.push(content::from_user(parts!["hi"]));
assert!(prepare(&messages, None, true).is_err());
let schema = schemars::json_schema!({"type": "object"});
assert!(prepare(&messages, Some(&schema), false).is_err());
}
// A single entry is sent verbatim; multiple entries get speaker labels.
#[test]
fn renders_single_and_multi_turn() {
assert_eq!(render_full(&[entry("user", "hi")]), "hi");
assert_eq!(
render_full(&[entry("user", "hi"), entry("assistant", "yo")]),
"User: hi\n\nAssistant: yo"
);
}
// Builder sampling applies on its own. A call-level `top_p` replaces the whole sampling
// family (clearing builder top_k/seed), while unrelated builder fields such as max_tokens
// survive.
#[test]
fn builder_defaults_yield_to_call_options_as_a_family() {
let config = Config {
lora: None,
temperature: Some(0.7),
max_tokens: Some(100),
sampling: Some(Sampling::TopK {
k: 40,
seed: Some(7),
}),
};
let wire = merge_options(&config, None).unwrap();
assert_eq!(wire.top_k, Some(40));
assert_eq!(wire.seed, Some(7));
let mut opts = ChatOptions::default();
opts.top_p = Some(0.75);
opts.temperature = Some(0.5);
let wire = merge_options(&config, Some(&opts)).unwrap();
assert_eq!(wire.top_p, Some(0.75));
assert_eq!(wire.top_k, None);
assert_eq!(wire.seed, None);
assert_eq!(wire.max_tokens, Some(100));
assert_eq!(wire.temperature, Some(0.5));
}
// The LoRA path and instructions both appear in the serialized session config.
#[test]
fn session_config_carries_lora() {
let config = Config {
lora: Some("adapters/transcripts.fmadapter".into()),
..Default::default()
};
let json = session_config_json(Some("sys"), &config).unwrap();
let v: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(v["instructions"], "sys");
assert_eq!(v["lora"], "adapters/transcripts.fmadapter");
}
}