defect_agent/hooks/prompt.rs
1//! Prompt hook handler — feeds a step envelope into a single LLM call.
2
3//! ## Not counted in the main loop's LLM call count
4//!
5//! The handler directly uses an [`Arc<dyn LlmProvider>`] to call
6//! [`LlmProvider::complete`],
7//! **without entering the history or counting toward `turn_request_count`** — this
8//! prevents
9//! a `SessionStart` hook from consuming one of the user's `max_turn_requests`.
10//!
11//! ## No nested Prompt handlers
12//!
13//! Internal LLM calls must not emit hook events, to avoid
14//! infinite recursion. This is guaranteed by the caller (the hook engine) — events
15//! entered via `fire` will not trigger hooks again due to LLM calls made inside the
16//! handler (there is no back-channel between the hook engine and the LLM provider;
17//! `provider.complete` is unaware of the hook system). No additional protection is
18//! needed on the handler side.
19//!
20//! ## Cold-start degradation
21//!
22//! If the LLM call on `SessionStart` fails, degrade per the degradation table — `SessionStart`
23//! must not block; errors are downgraded to warnings and the pipeline continues. This
24//! invariant is enforced by [`super::DefaultHookEngine`]; the handler only needs to
25//! propagate the error faithfully.
26//!
27//! [`LlmProvider::complete`]: crate::llm::LlmProvider::complete
28
29use std::sync::Arc;
30use std::time::Duration;
31
32use futures::StreamExt;
33use futures::future::BoxFuture;
34
35use crate::error::BoxError;
36use crate::llm::{
37 CompletionRequest, LlmProvider, Message, MessageContent, ProviderChunk, Role, SamplingParams,
38 ToolChoice,
39};
40
41use super::{HookCtx, HookError, StepHandler};
42
/// How the step envelope is turned into the user message for the LLM call.
///
/// [`PromptRender::Template`] does plain `{{key}}` string substitution, so no
/// templating engine (handlebars, tera, ...) is pulled in. The keys that can be
/// resolved are the envelope's top-level fields; see `render_envelope` for the
/// exact lookup rules. Keys typically available per event:
/// - every event: `{{event}}` / `{{cwd}}` / `{{session_id}}`
/// - PreToolUse / Post*: `{{tool}}` / `{{tool_input}}` / `{{tool_error}}`
/// - UserPromptSubmit: `{{prompt}}`
/// - SessionStart: `{{session_source}}`
///
/// A key that is not present in the envelope is substituted with an empty string.
/// This is deliberately conservative: a misconfigured template must not leak raw
/// `{{...}}` markers to the model.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PromptRender {
    /// Send the step envelope as-is, serialized to JSON.
    Json,
    /// Send `template` after substituting each `{{key}}` placeholder with the
    /// matching envelope field.
    Template { template: String },
}
62
63/// Configuration for the prompt handler.
64#[derive(Clone)]
65pub struct PromptSpec {
66 pub provider: Arc<dyn LlmProvider>,
67 /// `None` = use [`Self::fallback_model`] (the session default model).
68 pub model: Option<String>,
69 /// Used when `model` is `None` — the CLI assembly phase feeds in `TurnConfig::model`.
70 pub fallback_model: String,
71 pub system: String,
72 pub render: PromptRender,
73 pub timeout_sec: Option<u64>,
74}
75
76impl std::fmt::Debug for PromptSpec {
77 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
78 f.debug_struct("PromptSpec")
79 .field("provider", &self.provider.info())
80 .field("model", &self.model)
81 .field("fallback_model", &self.fallback_model)
82 .field("system", &self.system)
83 .field("render", &self.render)
84 .field("timeout_sec", &self.timeout_sec)
85 .finish()
86 }
87}
88
89/// Implementation of the `Prompt` handler.
90pub struct PromptHandler {
91 spec: PromptSpec,
92}
93
94impl PromptHandler {
95 #[must_use]
96 pub fn new(spec: PromptSpec) -> Self {
97 Self { spec }
98 }
99
100 #[must_use]
101 pub fn timeout(&self) -> Option<Duration> {
102 self.spec.timeout_sec.map(Duration::from_secs)
103 }
104}
105
106impl StepHandler for PromptHandler {
107 /// Renders the step envelope into user text (for JSON mode, serializes the envelope
108 /// directly; for Template mode, extracts top-level fields using `{{key}}`), runs one
109 /// LLM call, and uses the output text as the `additional_context` verdict.
110 fn handle_step<'a>(
111 &'a self,
112 envelope: &'a serde_json::Value,
113 ctx: HookCtx<'a>,
114 ) -> BoxFuture<'a, Result<Option<serde_json::Value>, HookError>> {
115 Box::pin(async move {
116 let user_text = render_envelope(envelope, &self.spec.render);
117 let request = CompletionRequest {
118 model: self
119 .spec
120 .model
121 .clone()
122 .unwrap_or_else(|| self.spec.fallback_model.clone()),
123 system: Some(Arc::from(self.spec.system.as_str())),
124 messages: vec![Message {
125 role: Role::User,
126 content: Arc::from([MessageContent::Text { text: user_text }]),
127 }],
128 tools: Vec::new(),
129 tool_choice: ToolChoice::None,
130 sampling: SamplingParams::default(),
131 hosted_capabilities: Default::default(),
132 };
133 let stream = self
134 .spec
135 .provider
136 .complete(request, ctx.cancel.clone())
137 .await
138 .map_err(|err| HookError::HandlerFailed(BoxError::new(err)))?;
139 let text = collect_text(stream).await?;
140 if text.is_empty() {
141 return Ok(None);
142 }
143 Ok(Some(serde_json::json!({ "additional_context": [text] })))
144 })
145 }
146}
147
148/// Renders the envelope: `Json` serializes it; `Template` replaces `{{key}}` with the
149/// top-level field value (strings and numbers are converted to text directly).
150fn render_envelope(envelope: &serde_json::Value, render: &PromptRender) -> String {
151 match render {
152 PromptRender::Json => serde_json::to_string(envelope).unwrap_or_default(),
153 PromptRender::Template { template } => {
154 let mut out = String::with_capacity(template.len());
155 let mut rest = template.as_str();
156 while let Some(start) = rest.find("{{") {
157 let Some((head, tail)) = rest.split_at_checked(start) else {
158 break;
159 };
160 out.push_str(head);
161 let Some(after_open) = tail.get(2..) else {
162 break;
163 };
164 let Some(close) = after_open.find("}}") else {
165 out.push_str(tail);
166 return out;
167 };
168 let Some(key) = after_open.get(..close).map(str::trim) else {
169 break;
170 };
171 match envelope.get(key) {
172 Some(serde_json::Value::String(s)) => out.push_str(s),
173 Some(other) => out.push_str(&other.to_string()),
174 None => {}
175 }
176 rest = match after_open.get(close + 2..) {
177 Some(s) => s,
178 None => break,
179 };
180 }
181 out.push_str(rest);
182 out
183 }
184 }
185}
186
187async fn collect_text(mut stream: crate::llm::ProviderStream) -> Result<String, HookError> {
188 let mut out = String::new();
189 while let Some(chunk) = stream.next().await {
190 match chunk {
191 Ok(ProviderChunk::TextDelta { text }) => out.push_str(&text),
192 Ok(ProviderChunk::Stop { .. }) => break,
193 Ok(_) => {} // Ignore thinking, tool_use, usage, etc.
194 Err(err) => {
195 return Err(HookError::HandlerFailed(BoxError::new(err)));
196 }
197 }
198 }
199 Ok(out)
200}
201
202#[cfg(test)]
203mod tests;