use std::sync::Arc;
use std::time::Duration;
use futures::StreamExt;
use futures::future::BoxFuture;
use crate::error::BoxError;
use crate::llm::{
CompletionRequest, LlmProvider, Message, MessageContent, ProviderChunk, Role, SamplingParams,
ToolChoice,
};
use super::{HookCtx, HookError, StepHandler};
/// How a hook envelope is turned into the text of the user message.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PromptRender {
    /// Send the whole envelope serialized as JSON.
    Json,
    /// Send `template` with `{{ key }}` placeholders filled from the envelope.
    Template {
        /// Template text containing `{{ key }}` placeholders.
        template: String,
    },
}
/// Configuration for a [`PromptHandler`]: which provider/model to call, with what
/// system prompt, and how to render the incoming envelope.
#[derive(Clone)]
pub struct PromptSpec {
    /// Provider whose `complete` is called for each step.
    pub provider: Arc<dyn LlmProvider>,
    /// Explicit model name; when `None`, `fallback_model` is used instead.
    pub model: Option<String>,
    /// Model name used when `model` is `None`.
    pub fallback_model: String,
    /// System prompt sent with every request.
    pub system: String,
    /// How the envelope is rendered into the user message.
    pub render: PromptRender,
    /// Optional timeout in whole seconds, exposed via [`PromptHandler::timeout`].
    pub timeout_sec: Option<u64>,
}
impl std::fmt::Debug for PromptSpec {
    /// Formats every field of the spec; `provider` is shown through its `info()`
    /// value rather than directly.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructuring keeps this impl in sync: adding a field is a compile error here.
        let Self {
            provider,
            model,
            fallback_model,
            system,
            render,
            timeout_sec,
        } = self;

        let mut builder = f.debug_struct("PromptSpec");
        builder.field("provider", &provider.info());
        builder.field("model", model);
        builder.field("fallback_model", fallback_model);
        builder.field("system", system);
        builder.field("render", render);
        builder.field("timeout_sec", timeout_sec);
        builder.finish()
    }
}
/// Step handler that asks an LLM about each envelope and returns its reply as
/// additional context.
pub struct PromptHandler {
    /// Immutable configuration supplied at construction.
    spec: PromptSpec,
}
impl PromptHandler {
#[must_use]
pub fn new(spec: PromptSpec) -> Self {
Self { spec }
}
#[must_use]
pub fn timeout(&self) -> Option<Duration> {
self.spec.timeout_sec.map(Duration::from_secs)
}
}
impl StepHandler for PromptHandler {
    /// Renders `envelope` into a single user message and sends it to the configured
    /// provider. The reply is collected as the concatenation of its text deltas.
    ///
    /// Returns `Ok(None)` when the reply text is empty. Otherwise it returns
    /// `Ok(Some({"additional_context": [text]}))`.
    ///
    /// # Errors
    ///
    /// Returns `HookError::HandlerFailed` if starting the completion fails or the
    /// provider stream yields an error.
    fn handle_step<'a>(
        &'a self,
        envelope: &'a serde_json::Value,
        ctx: HookCtx<'a>,
    ) -> BoxFuture<'a, Result<Option<serde_json::Value>, HookError>> {
        Box::pin(async move {
            let spec = &self.spec;

            let prompt = MessageContent::Text {
                text: render_envelope(envelope, &spec.render),
            };
            // An explicitly configured model wins; otherwise use the fallback.
            let model = match &spec.model {
                Some(explicit) => explicit.clone(),
                None => spec.fallback_model.clone(),
            };
            let user_message = Message {
                role: Role::User,
                content: Arc::from([prompt]),
            };
            let request = CompletionRequest {
                model,
                system: Some(Arc::from(spec.system.as_str())),
                messages: vec![user_message],
                tools: Vec::new(),
                tool_choice: ToolChoice::None,
                sampling: SamplingParams::default(),
                hosted_capabilities: Default::default(),
            };

            let stream = match spec.provider.complete(request, ctx.cancel.clone()).await {
                Ok(stream) => stream,
                Err(err) => return Err(HookError::HandlerFailed(BoxError::new(err))),
            };
            let reply = collect_text(stream).await?;

            // An empty reply contributes nothing, so no payload is emitted.
            Ok((!reply.is_empty())
                .then(|| serde_json::json!({ "additional_context": [reply] })))
        })
    }
}
/// Turns a hook `envelope` into the user-message text sent to the model.
///
/// * [`PromptRender::Json`]: the envelope serialized with `serde_json::to_string`.
///   An empty string is produced if serialization fails.
/// * [`PromptRender::Template`]: `template` with every `{{ key }}` placeholder
///   replaced by the envelope's value for `key`, with surrounding whitespace
///   trimmed from the key.
///   - String values are inserted raw.
///   - Other JSON values are inserted via their `to_string` form.
///   - Keys that `Value::get` does not find expand to nothing.
///   - An unterminated `{{`, and everything after it, is copied verbatim.
///
/// Substituted values are not re-scanned for placeholders.
fn render_envelope(envelope: &serde_json::Value, render: &PromptRender) -> String {
    match render {
        PromptRender::Json => serde_json::to_string(envelope).unwrap_or_default(),
        PromptRender::Template { template } => {
            let mut out = String::with_capacity(template.len());
            let mut rest = template.as_str();
            while let Some((head, after_open)) = rest.split_once("{{") {
                // Unterminated placeholder: stop without consuming anything, so the
                // verbatim copy of `rest` below emits `head`, the `{{` and the tail once.
                let Some((raw_key, after_close)) = after_open.split_once("}}") else {
                    break;
                };
                out.push_str(head);
                match envelope.get(raw_key.trim()) {
                    Some(serde_json::Value::String(s)) => out.push_str(s),
                    Some(other) => out.push_str(&other.to_string()),
                    None => {}
                }
                rest = after_close;
            }
            out.push_str(rest);
            out
        }
    }
}
/// Drains `stream` into a single string.
///
/// Every `TextDelta` is concatenated until the first `Stop` chunk or the end of the
/// stream. All other chunk kinds are ignored.
///
/// # Errors
///
/// Returns the first stream error, wrapped in `HookError::HandlerFailed`.
async fn collect_text(mut stream: crate::llm::ProviderStream) -> Result<String, HookError> {
    let mut collected = String::new();
    loop {
        let Some(item) = stream.next().await else {
            break;
        };
        match item.map_err(|err| HookError::HandlerFailed(BoxError::new(err)))? {
            ProviderChunk::TextDelta { text } => collected.push_str(&text),
            ProviderChunk::Stop { .. } => break,
            _ => {}
        }
    }
    Ok(collected)
}
// Unit tests for this module, loaded from a separate `tests` module file and
// compiled only in test builds.
#[cfg(test)]
mod tests;