use std::sync::Arc;
use oxillama_runtime::sampling::SamplerConfig;
use tokio::sync::oneshot;
/// Shared, reference-counted vocabulary table holding one `(u32, bytes)` entry per item.
///
/// Cloning a `VocabBytes` only bumps the `Arc` refcount; the table itself is not copied.
/// NOTE(review): presumably each entry is `(token id, raw token bytes)` — confirm with the producer.
pub type VocabBytes = std::sync::Arc<std::vec::Vec<(u32, std::vec::Vec<u8>)>>;
/// Token accounting reported back for a generation request.
///
/// Serialized with serde using the field names below as keys. Prefer [`UsageStats::new`]
/// over a struct literal so that `total_tokens` always equals the sum of the other two.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct UsageStats {
    /// Tokens attributed to the prompt.
    pub prompt_tokens: usize,
    /// Tokens attributed to the generated completion.
    pub completion_tokens: usize,
    /// `prompt_tokens + completion_tokens`.
    pub total_tokens: usize,
}

impl UsageStats {
    /// Builds usage stats with `total_tokens` derived from the two counts.
    ///
    /// Deriving the total here removes the need for every caller to keep the three fields
    /// consistent by hand. The sum saturates at `usize::MAX` rather than overflowing, since
    /// overflow would panic in debug builds and wrap in release builds.
    #[must_use]
    pub const fn new(prompt_tokens: usize, completion_tokens: usize) -> Self {
        Self {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens.saturating_add(completion_tokens),
        }
    }
}
/// Boxed, thread-sendable callback that receives `&str` fragments and may mutate captured state.
///
/// NOTE(review): presumably invoked once per streamed text chunk — confirm with the worker.
pub type StreamCallback = Box<dyn Send + FnMut(&str)>;
/// Ordered list of `(String, f32)` pairs selecting LoRA adapters for a request; may be empty.
///
/// NOTE(review): presumably `(adapter name, scale)` — confirm with the adapter loader.
pub type LoraSelection = std::vec::Vec<(std::string::String, f32)>;
/// A unit of work sent over a channel to the batch worker; each variant carries a
/// one-shot `reply` sender on which the worker delivers the result.
///
/// Errors are reported as `String` messages inside the reply's `Result`.
pub enum BatchRequest {
    /// Non-streaming generation: the reply carries the full generated text plus usage stats.
    Generate {
        /// Prompt text. Not printed by the `Debug` impl; only its byte length is shown.
        prompt: String,
        /// NOTE(review): presumably the upper bound on generated tokens — confirm with the worker.
        max_tokens: usize,
        /// Sampling configuration for this request.
        config: SamplerConfig,
        /// NOTE(review): presumably enables reuse of cached prompt state — confirm.
        cache_prompt: bool,
        /// LoRA adapters to apply; may be empty.
        lora_selection: LoraSelection,
        /// Receives `Ok((generated_text, usage))` or `Err(message)`.
        reply: oneshot::Sender<Result<(String, UsageStats), String>>,
    },
    /// Streaming generation: text is delivered through `callback`; the reply carries only usage stats.
    GenerateStream {
        /// Prompt text. Not printed by the `Debug` impl; only its byte length is shown.
        prompt: String,
        /// NOTE(review): presumably the upper bound on generated tokens — confirm with the worker.
        max_tokens: usize,
        /// Sampling configuration for this request.
        config: SamplerConfig,
        /// NOTE(review): presumably enables reuse of cached prompt state — confirm.
        cache_prompt: bool,
        /// LoRA adapters to apply; may be empty.
        lora_selection: LoraSelection,
        /// Receives `&str` fragments of output.
        /// NOTE(review): presumably called once per generated chunk — confirm.
        callback: StreamCallback,
        /// Receives `Ok(usage)` once the stream finishes, or `Err(message)`.
        reply: oneshot::Sender<Result<UsageStats, String>>,
    },
    /// Embedding request for a single piece of text.
    Embed {
        /// Text to embed. Not printed by the `Debug` impl; only its byte length is shown.
        text: String,
        /// Receives `Ok(vector)` or `Err(message)`.
        /// NOTE(review): vector length presumably equals `ModelMeta::hidden_size` — confirm.
        reply: oneshot::Sender<Result<Vec<f32>, String>>,
    },
}
/// Redacting `Debug` implementation.
///
/// Prompt and embedding text are summarized by their byte length (`str::len`) instead of
/// being printed, and LoRA selections are summarized by their count, so request contents
/// are never exposed through `{:?}` formatting. Fields that are not shown (`config`,
/// `callback`, `reply`) are signalled with a trailing `..` via
/// `DebugStruct::finish_non_exhaustive`.
impl std::fmt::Debug for BatchRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Formats the summary shared by both generation variants, which expose the same fields.
        fn generation_summary(
            f: &mut std::fmt::Formatter<'_>,
            name: &str,
            prompt_len: usize,
            max_tokens: usize,
            cache_prompt: bool,
            lora_count: usize,
        ) -> std::fmt::Result {
            f.debug_struct(name)
                .field("prompt_len", &prompt_len)
                .field("max_tokens", &max_tokens)
                .field("cache_prompt", &cache_prompt)
                .field("lora_count", &lora_count)
                .finish_non_exhaustive()
        }

        match self {
            BatchRequest::Generate {
                prompt,
                max_tokens,
                cache_prompt,
                lora_selection,
                ..
            } => generation_summary(
                f,
                "Generate",
                prompt.len(),
                *max_tokens,
                *cache_prompt,
                lora_selection.len(),
            ),
            BatchRequest::GenerateStream {
                prompt,
                max_tokens,
                cache_prompt,
                lora_selection,
                ..
            } => generation_summary(
                f,
                "GenerateStream",
                prompt.len(),
                *max_tokens,
                *cache_prompt,
                lora_selection.len(),
            ),
            BatchRequest::Embed { text, .. } => f
                .debug_struct("Embed")
                .field("text_len", &text.len())
                .finish_non_exhaustive(),
        }
    }
}
/// Metadata describing a model: default sampler settings, an optional vocabulary table,
/// and a hidden size.
///
/// Cloning shares `vocab_bytes` (it is an `Arc`) rather than copying the table.
#[derive(Debug, Clone)]
pub struct ModelMeta {
    /// NOTE(review): presumably the sampler used when a request does not override it — confirm.
    pub default_sampler: SamplerConfig,
    /// Shared vocabulary table; may be `None`.
    pub vocab_bytes: Option<VocabBytes>,
    /// NOTE(review): presumably the model's hidden/embedding dimension — confirm.
    pub hidden_size: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::oneshot;

    // Sends a `Generate` request through an mpsc channel, answers it on the oneshot reply as a
    // worker would, and checks the caller receives the answer unchanged.
    #[tokio::test]
    async fn test_generate_round_trip() {
        let (tx, mut rx) = tokio::sync::mpsc::channel::<BatchRequest>(8);
        let (reply_tx, reply_rx) = oneshot::channel::<Result<(String, UsageStats), String>>();
        tx.send(BatchRequest::Generate {
            prompt: "hello".to_string(),
            max_tokens: 16,
            config: SamplerConfig::default(),
            cache_prompt: true,
            lora_selection: vec![],
            reply: reply_tx,
        })
        .await
        .expect("channel should accept the request");
        let req = rx.recv().await.expect("worker should receive request");
        match req {
            BatchRequest::Generate {
                prompt,
                max_tokens,
                reply,
                ..
            } => {
                assert_eq!(prompt, "hello");
                assert_eq!(max_tokens, 16);
                let usage = UsageStats {
                    prompt_tokens: 1,
                    completion_tokens: 1,
                    total_tokens: 2,
                };
                reply
                    .send(Ok(("world".to_string(), usage)))
                    .expect("reply should succeed");
            }
            other => panic!("unexpected variant: {other:?}"),
        }
        let result = reply_rx.await.expect("reply future should resolve");
        let (text, usage) = result.expect("should be Ok");
        assert_eq!(text, "world");
        assert_eq!(usage.total_tokens, 2);
    }

    // The Debug impl must summarize the prompt by length instead of printing it.
    #[test]
    fn test_debug_does_not_expose_full_prompt() {
        let (reply_tx, _reply_rx) = oneshot::channel::<Result<(String, UsageStats), String>>();
        let req = BatchRequest::Generate {
            prompt: "secret prompt contents".to_string(),
            max_tokens: 32,
            config: SamplerConfig::default(),
            cache_prompt: true,
            lora_selection: vec![],
            reply: reply_tx,
        };
        let debug_str = format!("{req:?}");
        assert!(debug_str.contains("prompt_len"));
        assert!(!debug_str.contains("secret prompt contents"));
    }

    // Same redaction guarantee for the streaming variant; LoRA adapter names are also hidden
    // (only their count is shown).
    #[test]
    fn test_debug_stream_does_not_expose_full_prompt() {
        let (reply_tx, _reply_rx) = oneshot::channel::<Result<UsageStats, String>>();
        let req = BatchRequest::GenerateStream {
            prompt: "secret stream prompt".to_string(),
            max_tokens: 8,
            config: SamplerConfig::default(),
            cache_prompt: false,
            lora_selection: vec![("adapter".to_string(), 1.0)],
            callback: Box::new(|_chunk: &str| {}),
            reply: reply_tx,
        };
        let debug_str = format!("{req:?}");
        assert!(debug_str.contains("GenerateStream"));
        assert!(debug_str.contains("prompt_len"));
        assert!(debug_str.contains("lora_count: 1"));
        assert!(!debug_str.contains("secret stream prompt"));
        assert!(!debug_str.contains("adapter"));
    }

    // The embedding text must be summarized by its byte length, never printed.
    #[test]
    fn test_debug_embed_does_not_expose_text() {
        let (reply_tx, _reply_rx) = oneshot::channel::<Result<Vec<f32>, String>>();
        let req = BatchRequest::Embed {
            text: "private text".to_string(),
            reply: reply_tx,
        };
        let debug_str = format!("{req:?}");
        assert!(debug_str.contains("text_len: 12"));
        assert!(!debug_str.contains("private text"));
    }
}