use std::path::PathBuf;
use std::sync::Arc;
use chat_core::types::provider_meta::ProviderMeta;
use crate::api::types::request::{ConvoEntry, TurnPlan, hash_convo};
use crate::ffi;
/// Token sampling strategy that can be selected through [`Config::sampling`].
///
/// `Greedy` carries no parameters; `TopK` and `TopP` carry their cutoff (`k` / `p`) plus an
/// optional `seed`. How each variant is applied is decided outside this section (presumably
/// by the FFI layer — confirm there).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Sampling {
    /// Greedy decoding; takes no parameters.
    Greedy,
    /// Top-k sampling with cutoff `k` and an optional seed (presumably an RNG seed for
    /// reproducible output — TODO confirm).
    TopK { k: u32, seed: Option<u64> },
    /// Top-p (nucleus) sampling with threshold `p` and an optional seed (presumably an RNG
    /// seed for reproducible output — TODO confirm).
    TopP { p: f64, seed: Option<u64> },
}
/// Client configuration. Every field is optional; `Default` leaves all of them unset.
#[derive(Debug, Default)]
pub(crate) struct Config {
    /// Optional adapter path (presumably a LoRA adapter file). When set, its file stem is
    /// appended to the model slug and the full path is reported under the `"lora"` key of
    /// provider-specific response metadata.
    pub(crate) lora: Option<PathBuf>,
    /// Optional sampling-temperature override; consumed outside this section.
    pub(crate) temperature: Option<f64>,
    /// Optional cap on generated tokens (presumably per response — confirm at the use site).
    pub(crate) max_tokens: Option<u32>,
    /// Optional sampling-strategy override; consumed outside this section.
    pub(crate) sampling: Option<Sampling>,
}
/// Owning wrapper around a raw session id handed out by the FFI layer.
///
/// Dropping the wrapper passes the id to `ffi::session_free`, so wrapping the same id in two
/// handles would free it twice.
#[derive(Debug)]
struct SessionHandle(u64);

impl Drop for SessionHandle {
    fn drop(&mut self) {
        let raw_id = self.0;
        ffi::session_free(raw_id);
    }
}
/// Bookkeeping for the single live native session.
#[derive(Debug)]
struct SessionState {
    /// Owning handle; dropping this state frees the native session.
    handle: SessionHandle,
    /// Instructions hash the session was installed with; any other hash forces a rebuild.
    instructions_hash: u64,
    /// `hash_convo` of the conversation prefix recorded by `install` (empty) or `advance`.
    prefix_hash: u64,
    /// Number of conversation entries covered by `prefix_hash`.
    prefix_len: usize,
}
/// Slot holding at most one native session. `None` (the `Default`) means no session is
/// installed, so the next turn must rebuild.
#[derive(Debug, Default)]
pub(crate) struct Session(Option<SessionState>);
impl Session {
    /// Decides whether the next turn can continue the installed native session.
    ///
    /// Yields [`TurnPlan::Reuse`] only when all of these hold:
    /// - a session is installed;
    /// - it was installed with the same `instructions_hash`;
    /// - `convo` is the recorded prefix plus exactly one new trailing entry. This is checked
    ///   first by length, then by `hash_convo` of the first `prefix_len` entries.
    ///
    /// Anything else yields [`TurnPlan::Rebuild`].
    pub(crate) fn plan(&self, instructions_hash: u64, convo: &[ConvoEntry]) -> TurnPlan {
        let state = match &self.0 {
            Some(state) => state,
            None => return TurnPlan::Rebuild,
        };
        if state.instructions_hash != instructions_hash || convo.len() != state.prefix_len + 1 {
            return TurnPlan::Rebuild;
        }
        // The length check above guarantees this slice is in bounds.
        if hash_convo(&convo[..state.prefix_len]) == state.prefix_hash {
            TurnPlan::Reuse
        } else {
            TurnPlan::Rebuild
        }
    }

    /// Raw id of the installed native session, or `None` when the slot is empty.
    pub(crate) fn id(&self) -> Option<u64> {
        let state = self.0.as_ref()?;
        Some(state.handle.0)
    }

    /// Empties the slot. Dropping the previous state frees its native handle.
    pub(crate) fn invalidate(&mut self) {
        drop(self.0.take());
    }

    /// Takes ownership of the freshly created native session `id`.
    ///
    /// Any previously installed session is replaced, and therefore freed. The recorded
    /// prefix starts out empty.
    pub(crate) fn install(&mut self, id: u64, instructions_hash: u64) {
        let fresh = SessionState {
            handle: SessionHandle(id),
            instructions_hash,
            prefix_hash: hash_convo(&[]),
            prefix_len: 0,
        };
        self.0 = Some(fresh);
    }

    /// Records `convo` followed by an `"assistant"` entry holding `reply_text` as the
    /// session's new prefix. Does nothing when no session is installed.
    pub(crate) fn advance(&mut self, mut convo: Vec<ConvoEntry>, reply_text: String) {
        let state = match self.0.as_mut() {
            Some(state) => state,
            None => return,
        };
        let reply = ConvoEntry {
            role: "assistant",
            text: reply_text,
        };
        convo.push(reply);
        state.prefix_len = convo.len();
        state.prefix_hash = hash_convo(&convo);
    }
}
/// Client for the on-device model.
///
/// Cloning is cheap: every field is an `Arc`, so all clones share the same configuration,
/// provider metadata, and session slot.
#[derive(Clone, Debug)]
pub struct AppleFMClient {
    /// Client configuration, shared between clones.
    pub(crate) config: Arc<Config>,
    /// Provider metadata exposed through `provider_meta`.
    pub(crate) meta: Arc<ProviderMeta>,
    /// Shared session slot behind an async (`tokio`) mutex.
    pub(crate) session: Arc<tokio::sync::Mutex<Session>>,
}
impl AppleFMClient {
    /// Model identifier reported for this client.
    ///
    /// Returns `"apple-on-device"`, or `"apple-on-device+<stem>"` when a `lora` path with a
    /// file stem is configured. Non-UTF-8 stems are converted lossily.
    pub fn model_slug(&self) -> String {
        match self.config.lora.as_deref().and_then(|p| p.file_stem()) {
            Some(stem) => format!("apple-on-device+{}", stem.to_string_lossy()),
            None => "apple-on-device".to_owned(),
        }
    }

    /// Provider metadata this client was constructed with.
    pub fn provider_meta(&self) -> &ProviderMeta {
        &self.meta
    }

    /// Best-effort prewarm request.
    ///
    /// Passes the current session id to `ffi::prewarm`, or `0` when no session is installed
    /// (presumably a "no session" sentinel for the FFI layer — confirm there). If the session
    /// mutex is currently held, returns without doing anything.
    pub fn prewarm(&self) {
        if let Ok(session) = self.session.try_lock() {
            let id = session.id().unwrap_or(0);
            // The guard is released before the FFI call, presumably so the call does not hold
            // up tasks waiting on the session.
            // NOTE(review): this opens a window where a concurrent `invalidate`/`install` can
            // free `id` before `ffi::prewarm` uses it — confirm the FFI tolerates stale ids.
            drop(session);
            ffi::prewarm(id);
        }
    }

    /// Adds timing and provider-specific details to `metadata`.
    ///
    /// Fields set:
    /// - `duration_ms`: taken from `elapsed`, saturating at `u64::MAX`.
    /// - `created_at`: the current Unix time in seconds, or `None` if the system clock reads
    ///   before the epoch.
    ///
    /// Keys inserted into `provider_specific`:
    /// - `"prefill"`: `"incremental"` when `reused_session` is true, otherwise `"full"`.
    /// - `"lora"`: the configured path (lossily converted), only when one is set.
    pub(crate) fn enrich_metadata(
        &self,
        metadata: &mut chat_core::types::metadata::Metadata,
        elapsed: std::time::Duration,
        reused_session: bool,
    ) {
        // `as_millis` returns `u128`. Clamp before narrowing so the cast is lossless instead
        // of silently truncating.
        let elapsed_ms = elapsed.as_millis().min(u128::from(u64::MAX)) as u64;
        metadata.duration_ms = Some(elapsed_ms);
        metadata.created_at = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .ok()
            .map(|d| d.as_secs());
        let prefill = if reused_session { "incremental" } else { "full" };
        metadata.provider_specific.insert(
            "prefill".to_owned(),
            serde_json::Value::String(prefill.to_owned()),
        );
        if let Some(lora) = &self.config.lora {
            metadata.provider_specific.insert(
                "lora".to_owned(),
                serde_json::Value::String(lora.to_string_lossy().into_owned()),
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::types::request::hash_instructions;

    /// Builds one conversation entry with the given role and text.
    fn entry(role: &'static str, text: &str) -> ConvoEntry {
        ConvoEntry {
            role,
            text: text.to_owned(),
        }
    }

    /// Builds a conversation from `(role, text)` pairs, in order.
    fn convo(pairs: &[(&'static str, &str)]) -> Vec<ConvoEntry> {
        pairs.iter().map(|&(role, text)| entry(role, text)).collect()
    }

    #[test]
    fn session_lifecycle_plans_correctly() {
        let sys = hash_instructions(Some("sys"));
        let mut session = Session::default();

        // Nothing installed yet: the first turn must build a session.
        let first_turn = convo(&[("user", "hi")]);
        assert_eq!(session.plan(sys, &first_turn), TurnPlan::Rebuild);

        session.install(1, sys);
        session.advance(first_turn, "yo".to_owned());

        // Recorded prefix plus exactly one new entry: reuse.
        let second_turn = convo(&[
            ("user", "hi"),
            ("assistant", "yo"),
            ("user", "how are you?"),
        ]);
        assert_eq!(session.plan(sys, &second_turn), TurnPlan::Reuse);

        // Edited history no longer matches the recorded prefix hash.
        let edited = convo(&[
            ("user", "hi EDITED"),
            ("assistant", "yo"),
            ("user", "how are you?"),
        ]);
        assert_eq!(session.plan(sys, &edited), TurnPlan::Rebuild);

        // Different instructions cannot reuse the session.
        let other_sys = hash_instructions(Some("other"));
        assert_eq!(session.plan(other_sys, &second_turn), TurnPlan::Rebuild);

        // Once invalidated there is nothing left to reuse.
        session.invalidate();
        assert_eq!(session.id(), None);
        assert_eq!(session.plan(sys, &second_turn), TurnPlan::Rebuild);
    }
}