use std::marker::PhantomData;
use serde::Serialize;
use super::{NonStreaming, Streaming};
pub use crate::domain::llm::{ContentPart, FileSearchFilter, LlmMessage, LlmTool, Role};
/// Tenant and user identifiers associated with an LLM request.
///
/// Stored on [`LlmRequest`] when set through
/// [`LlmRequestBuilder::user_identity`].
#[derive(Debug, Clone)]
pub struct UserIdentity {
/// Tenant identifier.
pub tenant_id: String,
/// User identifier.
pub user_id: String,
}
/// Serializable metadata describing a single request.
///
/// Attached to a request through [`LlmRequestBuilder::metadata`]. Where the
/// serialized form ends up is not shown here.
#[derive(Debug, Clone, Serialize)]
pub struct RequestMetadata {
/// Tenant identifier.
pub tenant_id: String,
/// User identifier.
pub user_id: String,
/// Chat identifier.
pub chat_id: String,
/// Kind of request; see [`RequestType`] for the serialized names.
pub request_type: RequestType,
/// Feature flags for the request.
///
/// Serialized under the key `feature` (singular) as one string produced by
/// `serialize_feature`: `"none"` when the list is empty, otherwise the flag
/// names joined with `+`.
#[serde(rename = "feature", serialize_with = "serialize_feature")]
pub features: Vec<FeatureFlag>,
}
/// Serializes a list of feature flags as a single string.
///
/// Used as the `serialize_with` hook for `RequestMetadata::features`.
///
/// * An empty slice is written as the literal string `"none"`.
/// * Otherwise each flag's [`FeatureFlag::as_str`] name is written in slice
///   order, separated by `+` (for example `"file_search+web_search"`).
///
/// # Errors
///
/// Returns whatever error the underlying serializer reports from
/// `serialize_str`.
fn serialize_feature<S: serde::Serializer>(
    features: &[FeatureFlag],
    serializer: S,
) -> Result<S::Ok, S::Error> {
    // Peel off the first flag so separators are only emitted between names.
    let (head, tail) = match features.split_first() {
        Some(parts) => parts,
        None => return serializer.serialize_str("none"),
    };

    let mut joined = String::from(head.as_str());
    for flag in tail {
        joined.push('+');
        joined.push_str(flag.as_str());
    }

    serializer.serialize_str(&joined)
}
/// Kind of request recorded in [`RequestMetadata`].
///
/// Variants serialize in snake_case because of `rename_all = "snake_case"`.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum RequestType {
/// Serialized as `chat`.
Chat,
/// Serialized as `summary`.
Summary,
/// Serialized as `doc_summary`.
DocSummary,
}
/// Feature flag that can be listed in `RequestMetadata::features`.
///
/// The type is `Copy`, so flags are passed around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FeatureFlag {
    /// Rendered as `file_search`.
    FileSearch,
    /// Rendered as `web_search`.
    WebSearch,
    /// Rendered as `code_interpreter`.
    CodeInterpreter,
}

impl FeatureFlag {
    /// Returns the fixed snake_case name used when the flag is serialized.
    ///
    /// The match has no wildcard arm, so adding a variant without giving it
    /// a name is a compile error.
    fn as_str(self) -> &'static str {
        match self {
            FeatureFlag::FileSearch => "file_search",
            FeatureFlag::WebSearch => "web_search",
            FeatureFlag::CodeInterpreter => "code_interpreter",
        }
    }
}
/// A fully assembled LLM request.
///
/// `Mode` is a compile-time marker that defaults to [`Streaming`]. It is
/// carried only through `PhantomData`, so it stores no data. Values are
/// produced by [`LlmRequestBuilder::build_streaming`] and
/// [`LlmRequestBuilder::build_non_streaming`].
pub struct LlmRequest<Mode = Streaming> {
/// Model name given to [`LlmRequestBuilder::new`].
pub(crate) model: String,
/// Messages of the request, in the order they were supplied to the builder.
pub(crate) messages: Vec<LlmMessage>,
/// Optional system instructions; `None` if the builder never set them.
pub(crate) system_instructions: Option<String>,
/// Optional output-token limit; `None` if the builder never set one.
pub(crate) max_output_tokens: Option<u64>,
/// Tools attached to the request; empty if none were added.
pub(crate) tools: Vec<LlmTool>,
/// Optional tenant/user identity of the caller.
pub(crate) user_identity: Option<UserIdentity>,
/// Optional serializable request metadata.
pub(crate) metadata: Option<RequestMetadata>,
/// Optional limit on tool calls; `None` if the builder never set one.
pub(crate) max_tool_calls: Option<u32>,
/// Optional free-form JSON parameters. How they are consumed is not shown
/// here.
pub(crate) additional_params: Option<serde_json::Value>,
/// Zero-sized marker tying the request to its `Mode`.
pub(crate) _mode: PhantomData<Mode>,
}
/// Read-only accessors available for every request mode.
impl<M> LlmRequest<M> {
    /// Returns the model name this request was built with.
    #[must_use]
    pub fn model(&self) -> &str {
        self.model.as_str()
    }

    /// Returns the request's messages in their stored order.
    #[must_use]
    pub fn messages(&self) -> &[LlmMessage] {
        self.messages.as_slice()
    }

    /// Returns the tools attached to the request (possibly empty).
    #[must_use]
    pub fn tools(&self) -> &[LlmTool] {
        self.tools.as_slice()
    }
}
/// Builder for [`LlmRequest`].
///
/// Created with [`LlmRequestBuilder::new`], configured through the chained
/// setter methods, and finished with `build_streaming` or
/// `build_non_streaming`. The fields match those of [`LlmRequest`], minus the
/// mode marker.
pub struct LlmRequestBuilder {
// Model name; the only value required up front.
model: String,
// Messages collected so far.
messages: Vec<LlmMessage>,
// Optional system instructions.
system_instructions: Option<String>,
// Optional output-token limit.
max_output_tokens: Option<u64>,
// Tools collected so far.
tools: Vec<LlmTool>,
// Optional tenant/user identity.
user_identity: Option<UserIdentity>,
// Optional request metadata.
metadata: Option<RequestMetadata>,
// Optional tool-call limit.
max_tool_calls: Option<u32>,
// Optional free-form JSON parameters.
additional_params: Option<serde_json::Value>,
}
impl LlmRequestBuilder {
    /// Starts a builder for `model` with no messages, no tools, and every
    /// optional setting unset.
    #[must_use]
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            messages: Vec::new(),
            system_instructions: None,
            max_output_tokens: None,
            tools: Vec::new(),
            user_identity: None,
            metadata: None,
            max_tool_calls: None,
            additional_params: None,
        }
    }

    /// Appends one message after any messages already collected.
    #[must_use]
    pub fn message(mut self, message: LlmMessage) -> Self {
        self.messages.push(message);
        self
    }

    /// Replaces the whole message list, discarding anything added earlier.
    #[must_use]
    pub fn messages(self, messages: Vec<LlmMessage>) -> Self {
        Self { messages, ..self }
    }

    /// Sets the system instructions, overwriting any previous value.
    #[must_use]
    pub fn system_instructions(self, instructions: impl Into<String>) -> Self {
        Self {
            system_instructions: Some(instructions.into()),
            ..self
        }
    }

    /// Sets the output-token limit, overwriting any previous value.
    #[must_use]
    pub fn max_output_tokens(self, max_tokens: u64) -> Self {
        Self {
            max_output_tokens: Some(max_tokens),
            ..self
        }
    }

    /// Appends one tool after any tools already collected.
    #[must_use]
    pub fn tool(mut self, tool: LlmTool) -> Self {
        self.tools.push(tool);
        self
    }

    /// Replaces the whole tool list, discarding anything added earlier.
    #[must_use]
    pub fn tools(self, tools: Vec<LlmTool>) -> Self {
        Self { tools, ..self }
    }

    /// Records the tenant and user the request is made for, overwriting any
    /// previously set identity.
    #[must_use]
    pub fn user_identity(
        self,
        tenant_id: impl Into<String>,
        user_id: impl Into<String>,
    ) -> Self {
        let identity = UserIdentity {
            tenant_id: tenant_id.into(),
            user_id: user_id.into(),
        };
        Self {
            user_identity: Some(identity),
            ..self
        }
    }

    /// Attaches request metadata, overwriting any previous value.
    #[must_use]
    pub fn metadata(self, metadata: RequestMetadata) -> Self {
        Self {
            metadata: Some(metadata),
            ..self
        }
    }

    /// Sets the tool-call limit, overwriting any previous value.
    #[must_use]
    pub fn max_tool_calls(self, max: u32) -> Self {
        Self {
            max_tool_calls: Some(max),
            ..self
        }
    }

    /// Sets the free-form JSON parameters, overwriting any previous value.
    #[must_use]
    pub fn additional_params(self, params: serde_json::Value) -> Self {
        Self {
            additional_params: Some(params),
            ..self
        }
    }

    /// Moves every builder field into an [`LlmRequest`] tagged with `Mode`.
    ///
    /// The builder is destructured exhaustively, so a newly added field must
    /// be handled here before the code compiles.
    fn build_inner<Mode>(self) -> LlmRequest<Mode> {
        let Self {
            model,
            messages,
            system_instructions,
            max_output_tokens,
            tools,
            user_identity,
            metadata,
            max_tool_calls,
            additional_params,
        } = self;

        LlmRequest {
            model,
            messages,
            system_instructions,
            max_output_tokens,
            tools,
            user_identity,
            metadata,
            max_tool_calls,
            additional_params,
            _mode: PhantomData,
        }
    }

    /// Finishes the builder as a request tagged [`Streaming`].
    #[must_use]
    pub fn build_streaming(self) -> LlmRequest<Streaming> {
        self.build_inner::<Streaming>()
    }

    /// Finishes the builder as a request tagged [`NonStreaming`].
    #[must_use]
    pub fn build_non_streaming(self) -> LlmRequest<NonStreaming> {
        self.build_inner::<NonStreaming>()
    }
}