Skip to main content

defect_agent/hooks/
prompt.rs

1//! Prompt hook handler — feeds a step envelope into a single LLM call.
2
3//! ## Not counted in the main loop's LLM call count
4//!
//! The handler calls [`LlmProvider::complete`] directly through an `Arc<dyn LlmProvider>`,
//! **without recording the call in the history or counting it toward
//! `turn_request_count`**. This prevents, for example, a `SessionStart` hook from consuming
//! one of the user's `max_turn_requests`.
10//!
11//! ## No nested Prompt handlers
12//!
//! LLM calls made inside the handler must not emit hook events; otherwise hooks could
//! recurse indefinitely. This is guaranteed by the caller (the hook engine): an event
//! entered via `fire` cannot re-trigger hooks through the handler's own LLM call, because
//! there is no back-channel between the hook engine and the LLM provider
//! (`provider.complete` is unaware of the hook system). No additional protection is
//! needed on the handler side.
19//!
20//! ## Cold-start degradation
21//!
//! If the LLM call fails during `SessionStart`, the failure is degraded according to the
//! degradation table: `SessionStart` must never block, so errors are downgraded to
//! warnings and the pipeline continues. [`super::DefaultHookEngine`] enforces this
//! invariant; the handler only needs to propagate the error faithfully.
26//!
27//! [`LlmProvider::complete`]: crate::llm::LlmProvider::complete
28
29use std::sync::Arc;
30use std::time::Duration;
31
32use futures::StreamExt;
33use futures::future::BoxFuture;
34
35use crate::error::BoxError;
36use crate::llm::{
37    CompletionRequest, LlmProvider, Message, MessageContent, ProviderChunk, Role, SamplingParams,
38    ToolChoice,
39};
40
41use super::{HookCtx, HookError, StepHandler};
42
/// How a step envelope is turned into the user message sent to the model.
///
/// `Template` performs plain `{{key}}` substitution instead of pulling in a full
/// template engine such as handlebars or tera. A placeholder may name any top-level
/// field of the envelope; whitespace just inside the braces is ignored. Keys commonly
/// available per event:
/// - All events: `{{event}}` / `{{cwd}}` / `{{session_id}}`
/// - PreToolUse / Post*: `{{tool}}` / `{{tool_input}}` / `{{tool_error}}`
/// - UserPromptSubmit: `{{prompt}}`
/// - SessionStart: `{{session_source}}`
///
/// String values are inserted verbatim. Any other JSON value (number, boolean, object,
/// array, or `null`) is inserted in its compact JSON form. A key that is absent from the
/// envelope is replaced with an empty string. This conservative choice keeps a
/// misconfigured template from sending a raw `{{...}}` to the model. An unterminated
/// `{{` is kept as literal text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PromptRender {
    /// Send the step envelope serialized as JSON.
    Json,
    /// Substitute `{{key}}` placeholders in `template` with envelope fields.
    Template {
        /// Template text containing `{{key}}` placeholders.
        template: String,
    },
}
62
63/// Configuration for the prompt handler.
64#[derive(Clone)]
65pub struct PromptSpec {
66    pub provider: Arc<dyn LlmProvider>,
67    /// `None` = use [`Self::fallback_model`] (the session default model).
68    pub model: Option<String>,
69    /// Used when `model` is `None` — the CLI assembly phase feeds in `TurnConfig::model`.
70    pub fallback_model: String,
71    pub system: String,
72    pub render: PromptRender,
73    pub timeout_sec: Option<u64>,
74}
75
76impl std::fmt::Debug for PromptSpec {
77    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78        f.debug_struct("PromptSpec")
79            .field("provider", &self.provider.info())
80            .field("model", &self.model)
81            .field("fallback_model", &self.fallback_model)
82            .field("system", &self.system)
83            .field("render", &self.render)
84            .field("timeout_sec", &self.timeout_sec)
85            .finish()
86    }
87}
88
89/// Implementation of the `Prompt` handler.
90pub struct PromptHandler {
91    spec: PromptSpec,
92}
93
94impl PromptHandler {
95    #[must_use]
96    pub fn new(spec: PromptSpec) -> Self {
97        Self { spec }
98    }
99
100    #[must_use]
101    pub fn timeout(&self) -> Option<Duration> {
102        self.spec.timeout_sec.map(Duration::from_secs)
103    }
104}
105
106impl StepHandler for PromptHandler {
107    /// Renders the step envelope into user text (for JSON mode, serializes the envelope
108    /// directly; for Template mode, extracts top-level fields using `{{key}}`), runs one
109    /// LLM call, and uses the output text as the `additional_context` verdict.
110    fn handle_step<'a>(
111        &'a self,
112        envelope: &'a serde_json::Value,
113        ctx: HookCtx<'a>,
114    ) -> BoxFuture<'a, Result<Option<serde_json::Value>, HookError>> {
115        Box::pin(async move {
116            let user_text = render_envelope(envelope, &self.spec.render);
117            let request = CompletionRequest {
118                model: self
119                    .spec
120                    .model
121                    .clone()
122                    .unwrap_or_else(|| self.spec.fallback_model.clone()),
123                system: Some(Arc::from(self.spec.system.as_str())),
124                messages: vec![Message {
125                    role: Role::User,
126                    content: Arc::from([MessageContent::Text { text: user_text }]),
127                }],
128                tools: Vec::new(),
129                tool_choice: ToolChoice::None,
130                sampling: SamplingParams::default(),
131                hosted_capabilities: Default::default(),
132            };
133            let stream = self
134                .spec
135                .provider
136                .complete(request, ctx.cancel.clone())
137                .await
138                .map_err(|err| HookError::HandlerFailed(BoxError::new(err)))?;
139            let text = collect_text(stream).await?;
140            if text.is_empty() {
141                return Ok(None);
142            }
143            Ok(Some(serde_json::json!({ "additional_context": [text] })))
144        })
145    }
146}
147
148/// Renders the envelope: `Json` serializes it; `Template` replaces `{{key}}` with the
149/// top-level field value (strings and numbers are converted to text directly).
150fn render_envelope(envelope: &serde_json::Value, render: &PromptRender) -> String {
151    match render {
152        PromptRender::Json => serde_json::to_string(envelope).unwrap_or_default(),
153        PromptRender::Template { template } => {
154            let mut out = String::with_capacity(template.len());
155            let mut rest = template.as_str();
156            while let Some(start) = rest.find("{{") {
157                let Some((head, tail)) = rest.split_at_checked(start) else {
158                    break;
159                };
160                out.push_str(head);
161                let Some(after_open) = tail.get(2..) else {
162                    break;
163                };
164                let Some(close) = after_open.find("}}") else {
165                    out.push_str(tail);
166                    return out;
167                };
168                let Some(key) = after_open.get(..close).map(str::trim) else {
169                    break;
170                };
171                match envelope.get(key) {
172                    Some(serde_json::Value::String(s)) => out.push_str(s),
173                    Some(other) => out.push_str(&other.to_string()),
174                    None => {}
175                }
176                rest = match after_open.get(close + 2..) {
177                    Some(s) => s,
178                    None => break,
179                };
180            }
181            out.push_str(rest);
182            out
183        }
184    }
185}
186
187async fn collect_text(mut stream: crate::llm::ProviderStream) -> Result<String, HookError> {
188    let mut out = String::new();
189    while let Some(chunk) = stream.next().await {
190        match chunk {
191            Ok(ProviderChunk::TextDelta { text }) => out.push_str(&text),
192            Ok(ProviderChunk::Stop { .. }) => break,
193            Ok(_) => {} // Ignore thinking, tool_use, usage, etc.
194            Err(err) => {
195                return Err(HookError::HandlerFailed(BoxError::new(err)));
196            }
197        }
198    }
199    Ok(out)
200}
201
// Unit tests for this module live out of line in a separate `tests` module file and are
// compiled only for test builds.
#[cfg(test)]
mod tests;