#[cfg(target_arch = "wasm32")]
use crate::tokio;
use async_trait::async_trait;
use futures::StreamExt;
use meerkat_core::lifecycle::run_primitive::{ProviderParamsOverride, ProviderTag};
use meerkat_core::schema::{CompiledSchema, SchemaError};
use meerkat_core::{
AgentError, AgentEvent, AgentLlmClient, LlmStreamResult, Message, OutputSchema, StopReason,
ToolDef, Usage,
};
use std::sync::Arc;
use tokio::sync::mpsc;
use crate::block_assembler::BlockAssembler;
use crate::types::{LlmClient, LlmDoneOutcome, LlmEvent, LlmRequest};
/// Adapts a provider-specific [`LlmClient`] to the `AgentLlmClient` trait
/// from meerkat-core: it builds requests, consumes the provider's event
/// stream, assembles complete blocks, and forwards streaming deltas to an
/// optional channel and an event tap.
#[derive(Clone)]
pub struct LlmClientAdapter {
    // Underlying provider client used to open response streams.
    client: Arc<dyn LlmClient>,
    // Model identifier sent with every request.
    model: String,
    // Optional channel that receives streaming `AgentEvent`s (best-effort;
    // send failures are ignored in `stream_response`).
    event_tx: Option<mpsc::Sender<AgentEvent>>,
    // Adapter-level default provider tag; per-call overrides take precedence.
    provider_params: Option<ProviderTag>,
    // Tap that also receives streaming events via `tap_try_send`.
    event_tap: meerkat_core::EventTap,
}
impl LlmClientAdapter {
    /// Creates an adapter with no event channel and no default provider tag.
    pub fn new(client: Arc<dyn LlmClient>, model: String) -> Self {
        Self {
            client,
            model,
            event_tx: None,
            provider_params: None,
            event_tap: meerkat_core::new_event_tap(),
        }
    }
    /// Creates an adapter that forwards streaming events to `event_tx`.
    pub fn with_event_channel(
        client: Arc<dyn LlmClient>,
        model: String,
        event_tx: mpsc::Sender<AgentEvent>,
    ) -> Self {
        // Reuse `new` for the common fields; only the channel differs.
        Self {
            event_tx: Some(event_tx),
            ..Self::new(client, model)
        }
    }
    /// Sets the adapter-level default provider tag (builder style).
    pub fn with_provider_params(mut self, params: Option<ProviderTag>) -> Self {
        self.provider_params = params;
        self
    }
    /// Replaces the event tap (builder style).
    pub fn with_event_tap(mut self, tap: meerkat_core::EventTap) -> Self {
        self.event_tap = tap;
        self
    }
    /// Drops provider tool-override bodies whose JSON value is not an
    /// object; all other tag contents pass through unchanged.
    fn strip_non_object_provider_tool_overrides(tag: ProviderTag) -> ProviderTag {
        // The override lives in a differently named field per provider, so
        // the same check is spelled out once per variant.
        match tag {
            ProviderTag::Anthropic(mut inner) => {
                let keep = inner
                    .web_search
                    .as_ref()
                    .map_or(true, |body| body.as_value().is_object());
                if !keep {
                    inner.web_search = None;
                }
                ProviderTag::Anthropic(inner)
            }
            ProviderTag::OpenAi(mut inner) => {
                let keep = inner
                    .web_search
                    .as_ref()
                    .map_or(true, |body| body.as_value().is_object());
                if !keep {
                    inner.web_search = None;
                }
                ProviderTag::OpenAi(inner)
            }
            ProviderTag::Gemini(mut inner) => {
                let keep = inner
                    .google_search
                    .as_ref()
                    .map_or(true, |body| body.as_value().is_object());
                if !keep {
                    inner.google_search = None;
                }
                ProviderTag::Gemini(inner)
            }
            other => other,
        }
    }
    /// Folds provider-agnostic override fields (thinking budget, top_p)
    /// into the provider tag for the current provider, creating a default
    /// tag when none exists. Tags for a mismatched provider are returned
    /// untouched, as is the input when no overrides apply.
    fn apply_generic_provider_overrides(
        &self,
        tag: Option<ProviderTag>,
        params: Option<&ProviderParamsOverride>,
    ) -> Option<ProviderTag> {
        use meerkat_core::lifecycle::run_primitive::{AnthropicProviderTag, GeminiProviderTag};
        let params = match params {
            Some(p) => p,
            None => return tag,
        };
        let provider = self.client.provider();
        if provider == "anthropic" && params.thinking_budget_tokens.is_some() {
            return match tag {
                Some(ProviderTag::Anthropic(mut inner)) => {
                    inner.thinking_budget_tokens = params.thinking_budget_tokens;
                    Some(ProviderTag::Anthropic(inner))
                }
                None => Some(ProviderTag::Anthropic(AnthropicProviderTag {
                    thinking_budget_tokens: params.thinking_budget_tokens,
                    ..Default::default()
                })),
                // A tag for a different provider is left as-is.
                mismatched => mismatched,
            };
        }
        if (provider == "gemini" || provider == "google")
            && (params.top_p.is_some() || params.thinking_budget_tokens.is_some())
        {
            return match tag {
                Some(ProviderTag::Gemini(mut inner)) => {
                    if let Some(top_p) = params.top_p {
                        inner.top_p = Some(top_p);
                    }
                    if let Some(budget) = params.thinking_budget_tokens {
                        inner.thinking_budget = Some(budget);
                    }
                    Some(ProviderTag::Gemini(inner))
                }
                None => Some(ProviderTag::Gemini(GeminiProviderTag {
                    top_p: params.top_p,
                    thinking_budget: params.thinking_budget_tokens,
                    ..Default::default()
                })),
                // A tag for a different provider is left as-is.
                mismatched => mismatched,
            };
        }
        tag
    }
}
/// Returns an empty JSON object (`{}`) as a boxed `RawValue`; used as a
/// fallback when re-serializing tool-call arguments fails.
#[allow(clippy::unwrap_used, clippy::expect_used)]
fn fallback_raw_value() -> Box<serde_json::value::RawValue> {
    // "{}" is a statically valid JSON document, so this cannot panic.
    serde_json::value::RawValue::from_string(String::from("{}"))
        .expect("static JSON is valid")
}
// `async_trait` futures must be `Send` on native targets; on wasm32 the
// single-threaded runtime allows (and requires) `?Send`.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl AgentLlmClient for LlmClientAdapter {
    /// Streams one model response: builds the request, drains the provider
    /// event stream through a `BlockAssembler`, forwards deltas to the event
    /// tap and optional channel, and returns the assembled result.
    ///
    /// Per-call `provider_params` override the adapter defaults for the
    /// provider tag, `max_tokens`, and `temperature`. Returns
    /// `AgentError::llm` if the stream yields an error or a `Done` event
    /// with an error outcome.
    async fn stream_response(
        &self,
        messages: &[Message],
        tools: &[Arc<ToolDef>],
        max_tokens: u32,
        temperature: Option<f32>,
        provider_params: Option<&ProviderParamsOverride>,
    ) -> Result<LlmStreamResult, AgentError> {
        // Per-call provider tag wins over the adapter-level default.
        let effective_params = provider_params
            .and_then(|params| params.provider_tag.clone())
            .or_else(|| self.provider_params.clone());
        // Fold in generic overrides (thinking budget, top_p), then drop any
        // tool-override bodies that are not JSON objects.
        let effective_params =
            self.apply_generic_provider_overrides(effective_params, provider_params);
        let effective_params = effective_params.map(Self::strip_non_object_provider_tool_overrides);
        let effective_max_tokens = provider_params
            .and_then(|params| params.max_output_tokens)
            .unwrap_or(max_tokens);
        let effective_temperature = provider_params
            .and_then(|params| params.temperature)
            .or(temperature);
        let request = LlmRequest {
            model: self.model.clone(),
            messages: messages.to_vec(),
            tools: tools.to_vec(),
            max_tokens: effective_max_tokens,
            temperature: effective_temperature,
            stop_sequences: None,
            provider_params: effective_params,
        };
        let mut stream = self.client.stream(&request);
        let mut assembler = BlockAssembler::new();
        // True while a reasoning block is open in the assembler.
        let mut reasoning_started = false;
        let mut stop_reason = StopReason::EndTurn;
        let mut usage = Usage::default();
        while let Some(result) = stream.next().await {
            match result {
                Ok(event) => match event {
                    LlmEvent::TextDelta { delta, meta } => {
                        assembler.on_text_delta(&delta, meta);
                        // Tap and channel forwarding are best-effort: send
                        // failures are deliberately ignored throughout.
                        meerkat_core::tap_try_send(
                            &self.event_tap,
                            &AgentEvent::TextDelta {
                                delta: delta.clone(),
                            },
                        );
                        if let Some(ref tx) = self.event_tx {
                            let _ = tx.send(AgentEvent::TextDelta { delta }).await;
                        }
                    }
                    LlmEvent::ReasoningDelta { delta } => {
                        // Lazily open the reasoning block on the first delta.
                        if !reasoning_started {
                            reasoning_started = true;
                            assembler.on_reasoning_start();
                        }
                        if let Err(e) = assembler.on_reasoning_delta(&delta) {
                            tracing::warn!(?e, "orphaned reasoning delta");
                        }
                        meerkat_core::tap_try_send(
                            &self.event_tap,
                            &AgentEvent::ReasoningDelta {
                                delta: delta.clone(),
                            },
                        );
                        if let Some(ref tx) = self.event_tx {
                            let _ = tx.send(AgentEvent::ReasoningDelta { delta }).await;
                        }
                    }
                    LlmEvent::ReasoningComplete { text, meta } => {
                        // If no deltas arrived for this block, synthesize it
                        // from the complete text. `text` is otherwise unused:
                        // the accumulated delta text is authoritative.
                        if !reasoning_started {
                            assembler.on_reasoning_start();
                            let _ = assembler.on_reasoning_delta(&text);
                        }
                        // Capture accumulated text before closing the block.
                        let reasoning_text = assembler.current_reasoning_text();
                        assembler.on_reasoning_complete(meta);
                        reasoning_started = false;
                        meerkat_core::tap_try_send(
                            &self.event_tap,
                            &AgentEvent::ReasoningComplete {
                                content: reasoning_text.clone(),
                            },
                        );
                        if let Some(ref tx) = self.event_tx {
                            let _ = tx
                                .send(AgentEvent::ReasoningComplete {
                                    content: reasoning_text,
                                })
                                .await;
                        }
                    }
                    LlmEvent::ToolCallDelta {
                        id,
                        name,
                        args_delta,
                    } => {
                        if let Err(e) =
                            assembler.on_tool_call_delta(&id, name.as_deref(), &args_delta)
                        {
                            // An orphaned delta means no start event was seen
                            // for this id; implicitly open the call and retry
                            // the delta once.
                            if matches!(
                                e,
                                crate::block_assembler::StreamAssemblyError::OrphanedToolDelta(_)
                            ) {
                                let _ = assembler.on_tool_call_start(id.clone());
                                if let Err(e) =
                                    assembler.on_tool_call_delta(&id, name.as_deref(), &args_delta)
                                {
                                    tracing::warn!(?e, "orphaned tool delta");
                                }
                            } else {
                                tracing::warn!(?e, "tool delta error");
                            }
                        }
                    }
                    LlmEvent::ToolCallComplete {
                        id,
                        name,
                        args,
                        meta,
                    } => {
                        // NOTE(review): redundant rebinding — `meta` could be
                        // passed directly below.
                        let effective_meta = meta;
                        // Re-serialize args into a RawValue; fall back to an
                        // empty JSON object if serialization fails.
                        let args_raw = match serde_json::to_string(&args)
                            .ok()
                            .and_then(|s| serde_json::value::RawValue::from_string(s).ok())
                        {
                            Some(raw) => raw,
                            None => fallback_raw_value(),
                        };
                        let _ = assembler.on_tool_call_complete(id, name, args_raw, effective_meta);
                    }
                    LlmEvent::ServerToolContent {
                        id,
                        name,
                        content,
                        meta,
                    } => {
                        // Clone the id before it is moved into the assembler.
                        let event_id = id.clone();
                        assembler.on_server_tool_content(id, name.clone(), content.clone(), meta);
                        if let Some(ref tx) = self.event_tx {
                            let _ = tx
                                .send(AgentEvent::ServerToolContent {
                                    id: event_id,
                                    name,
                                    content,
                                })
                                .await;
                        }
                    }
                    LlmEvent::UsageUpdate { usage: u } => {
                        // Later updates overwrite earlier ones; the last one
                        // seen is returned.
                        usage = u;
                    }
                    LlmEvent::Done { outcome } => match outcome {
                        LlmDoneOutcome::Success { stop_reason: sr } => {
                            stop_reason = sr;
                        }
                        LlmDoneOutcome::Error { error } => {
                            return Err(AgentError::llm(
                                self.client.provider(),
                                error.failure_reason(),
                                error.to_string(),
                            ));
                        }
                    },
                },
                // Transport-level stream error: abort the turn.
                Err(e) => {
                    return Err(AgentError::llm(
                        self.client.provider(),
                        e.failure_reason(),
                        e.to_string(),
                    ));
                }
            }
        }
        Ok(LlmStreamResult::new(
            assembler.finalize(),
            stop_reason,
            usage,
        ))
    }
    /// Provider name, delegated to the underlying client.
    fn provider(&self) -> &'static str {
        self.client.provider()
    }
    /// Model identifier this adapter was constructed with.
    fn model(&self) -> &str {
        &self.model
    }
    /// Schema compilation, delegated to the underlying client.
    fn compile_schema(&self, output_schema: &OutputSchema) -> Result<CompiledSchema, SchemaError> {
        self.client.compile_schema(output_schema)
    }
}