// bob_runtime/prompt.rs

//! # Prompt Builder
//!
//! Prompt builder — assembles `LlmRequest` from session state, tool
//! descriptors, and system instructions.
//!
//! ## Overview
//!
//! The prompt builder constructs the complete LLM request by combining:
//!
//! 1. **System instructions**: Core instructions + action schema + tool schemas
//! 2. **Session history**: Message history (truncated to most recent 50 non-system messages)
//! 3. **Tool definitions**: Available tools and their schemas
//!
//! ## Components
//!
//! - **Action Schema**: JSON schema for the action protocol
//! - **Tool Schema Block**: Formatted list of available tools
//! - **History Truncation**: Keeps most recent messages to fit context limits
//!
//! ## Example
//!
//! ```rust,ignore
//! use bob_runtime::prompt::build_llm_request;
//! use bob_core::types::{SessionState, ToolDescriptor};
//!
//! let session = SessionState::default();
//! let tools = vec![];
//! let request = build_llm_request("openai:gpt-4o-mini", &session, &tools, "You are helpful.");
//! ```
use bob_core::types::{LlmRequest, Message, Role, SessionState, ToolDescriptor};

/// Maximum number of non-system history messages to keep when building a
/// request. System messages are exempt from this cap — see `truncate_history`.
const MAX_HISTORY: usize = 50;
35
/// Options controlling prompt shape for different dispatch strategies.
#[derive(Debug, Clone, Copy)]
pub(crate) struct PromptBuildOptions {
    /// When true, append the JSON action-schema contract to the system message.
    pub include_action_schema: bool,
    /// When true, append the rendered tool-schema block to the system message.
    pub include_tool_schema: bool,
}
42
43impl Default for PromptBuildOptions {
44    fn default() -> Self {
45        Self { include_action_schema: true, include_tool_schema: true }
46    }
47}
48
/// Returns the JSON action-schema contract text (design doc §8.3.1).
///
/// The returned text instructs the model to reply with exactly one JSON
/// object matching the action protocol (`final` / `tool_call` / `ask_user`).
pub(crate) fn action_schema_prompt() -> String {
    String::from(
        r#"You must respond with exactly one JSON object and no extra text.
Schema:
{
  "type": "final" | "tool_call" | "ask_user",
  "content": "string (required when type=final)",
  "name": "string (required when type=tool_call)",
  "arguments": "object (required when type=tool_call)",
  "question": "string (required when type=ask_user)"
}"#,
    )
}
62
63/// Renders tool names, descriptions, and input schemas as a text block.
64///
65/// Returns an empty string when no tools are available.
66pub(crate) fn tool_schema_block(tools: &[ToolDescriptor]) -> String {
67    if tools.is_empty() {
68        return String::new();
69    }
70
71    let mut buf = String::from("Available tools:\n");
72    for tool in tools {
73        buf.push_str(&format!(
74            "\n- **{}**: {}\n  Input schema: {}\n",
75            tool.id,
76            tool.description,
77            serde_json::to_string_pretty(&tool.input_schema).unwrap_or_default(),
78        ));
79    }
80    buf
81}
82
83/// Assembles a complete `LlmRequest`:
84///   1. System message = core instructions + action schema + tool schemas
85///   2. Session history (truncated to most recent 50 non-system messages)
86///   3. `LlmRequest { model, messages, tools }`
87#[cfg_attr(
88    not(test),
89    expect(
90        dead_code,
91        reason = "compatibility wrapper retained for callers that use default prompt build options"
92    )
93)]
94pub(crate) fn build_llm_request(
95    model: &str,
96    session: &SessionState,
97    tools: &[ToolDescriptor],
98    system_instructions: &str,
99) -> LlmRequest {
100    build_llm_request_with_options(
101        model,
102        session,
103        tools,
104        system_instructions,
105        PromptBuildOptions::default(),
106    )
107}
108
109/// Assembles an `LlmRequest` with configurable schema/tool prompt sections.
110pub(crate) fn build_llm_request_with_options(
111    model: &str,
112    session: &SessionState,
113    tools: &[ToolDescriptor],
114    system_instructions: &str,
115    options: PromptBuildOptions,
116) -> LlmRequest {
117    // -- system message --------------------------------------------------
118    let mut system_content = system_instructions.to_string();
119    if options.include_action_schema {
120        system_content.push_str("\n\n");
121        system_content.push_str(&action_schema_prompt());
122    }
123
124    let tool_block =
125        if options.include_tool_schema { tool_schema_block(tools) } else { String::new() };
126    if !tool_block.is_empty() {
127        system_content.push_str("\n\n");
128        system_content.push_str(&tool_block);
129    }
130
131    let system_msg = Message { role: Role::System, content: system_content };
132
133    // -- history (truncated) ---------------------------------------------
134    let history = truncate_history(&session.messages, MAX_HISTORY);
135
136    // -- assemble --------------------------------------------------------
137    let mut messages = Vec::with_capacity(1 + history.len());
138    messages.push(system_msg);
139    messages.extend(history);
140
141    LlmRequest { model: model.to_string(), messages, tools: tools.to_vec() }
142}
143
144/// Keeps at most `max` non-system messages, dropping the oldest first.
145/// System messages are never dropped.
146pub(crate) fn truncate_history(messages: &[Message], max: usize) -> Vec<Message> {
147    let non_system_count = messages.iter().filter(|m| m.role != Role::System).count();
148
149    if non_system_count <= max {
150        return messages.to_vec();
151    }
152
153    let to_drop = non_system_count - max;
154    let mut dropped = 0usize;
155    let mut result = Vec::with_capacity(messages.len() - to_drop);
156
157    for m in messages {
158        if m.role == Role::System {
159            // System messages are always kept.
160            result.push(m.clone());
161        } else if dropped < to_drop {
162            // Drop the oldest non-system messages.
163            dropped += 1;
164        } else {
165            result.push(m.clone());
166        }
167    }
168
169    result
170}
171
#[cfg(test)]
mod tests {
    use bob_core::types::{SessionState, TokenUsage, ToolSource};
    use serde_json::json;

    use super::*;

    // ── Helpers ──────────────────────────────────────────────────────

    fn make_tool(id: &str) -> ToolDescriptor {
        ToolDescriptor {
            id: id.to_string(),
            description: format!("{id} description"),
            input_schema: json!({"type": "object", "properties": {"path": {"type": "string"}}}),
            source: ToolSource::Local,
        }
    }

    fn msg(role: Role, content: &str) -> Message {
        Message { role, content: content.to_string() }
    }

    // ── action_schema_prompt ─────────────────────────────────────────

    #[test]
    fn prompt_action_schema_contains_required_types() {
        let schema = action_schema_prompt();
        assert!(schema.contains("final"), "must mention 'final' action type");
        assert!(schema.contains("tool_call"), "must mention 'tool_call' action type");
        assert!(schema.contains("ask_user"), "must mention 'ask_user' action type");
    }

    #[test]
    fn prompt_action_schema_mentions_json() {
        let schema = action_schema_prompt();
        assert!(schema.contains("JSON"), "must instruct the LLM to respond with JSON");
    }

    // ── tool_schema_block ────────────────────────────────────────────

    #[test]
    fn prompt_tool_schema_empty() {
        let block = tool_schema_block(&[]);
        // `tool_schema_block` documents "empty string when no tools" — pin that
        // contract exactly rather than the weaker `is_empty() || contains(...)`.
        assert!(block.is_empty(), "empty tools must produce an empty block");
    }

    #[test]
    fn prompt_tool_schema_renders_names_and_descriptions() {
        let tools = vec![make_tool("read_file"), make_tool("write_file")];
        let block = tool_schema_block(&tools);
        assert!(block.contains("read_file"), "must include tool name");
        assert!(block.contains("read_file description"), "must include description");
        assert!(block.contains("write_file"), "must include second tool");
    }

    #[test]
    fn prompt_tool_schema_renders_input_schema() {
        let tools = vec![make_tool("grep")];
        let block = tool_schema_block(&tools);
        assert!(block.contains("path"), "must include input_schema fields");
    }

    // ── truncate_history ─────────────────────────────────────────────

    #[test]
    fn prompt_truncate_noop_when_under_limit() {
        let msgs = vec![msg(Role::User, "a"), msg(Role::Assistant, "b")];
        let result = truncate_history(&msgs, 50);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn prompt_truncate_drops_oldest_non_system() {
        let mut msgs: Vec<Message> =
            (0..60).map(|i| msg(Role::User, &format!("msg-{i}"))).collect();
        // Prepend a system message.
        msgs.insert(0, msg(Role::System, "sys"));
        let result = truncate_history(&msgs, 50);
        // System message is kept, plus the 50 most recent non-system messages.
        assert_eq!(result.len(), 51);
        assert_eq!(result[0].role, Role::System);
        // The oldest kept non-system should be msg-10 (dropped 0..10).
        assert!(result[1].content.contains("msg-10"));
    }

    #[test]
    fn prompt_truncate_keeps_all_system_messages() {
        let msgs = vec![
            msg(Role::System, "sys-1"),
            msg(Role::User, "u1"),
            msg(Role::System, "sys-2"),
            msg(Role::User, "u2"),
            msg(Role::Assistant, "a1"),
        ];
        let result = truncate_history(&msgs, 2);
        // Both system messages kept + 2 most recent non-system (u2, a1).
        assert_eq!(result.len(), 4);
        let system_count = result.iter().filter(|m| m.role == Role::System).count();
        assert_eq!(system_count, 2);
    }

    #[test]
    fn prompt_truncate_preserves_order() {
        let msgs = vec![
            msg(Role::System, "sys"),
            msg(Role::User, "old"),
            msg(Role::User, "mid"),
            msg(Role::User, "new"),
        ];
        let result = truncate_history(&msgs, 2);
        assert_eq!(result.len(), 3); // sys + mid + new
        assert_eq!(result[0].content, "sys");
        assert_eq!(result[1].content, "mid");
        assert_eq!(result[2].content, "new");
    }

    #[test]
    fn prompt_truncate_empty_history() {
        let result = truncate_history(&[], 50);
        assert!(result.is_empty());
    }

    #[test]
    fn prompt_truncate_exactly_at_limit() {
        let msgs: Vec<Message> = (0..50).map(|i| msg(Role::User, &format!("u-{i}"))).collect();
        let result = truncate_history(&msgs, 50);
        assert_eq!(result.len(), 50, "no messages should be dropped at exact limit");
        assert_eq!(result[0].content, "u-0");
        assert_eq!(result[49].content, "u-49");
    }

    #[test]
    fn prompt_truncate_single_message() {
        let msgs = vec![msg(Role::User, "only")];
        let result = truncate_history(&msgs, 50);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].content, "only");
    }

    #[test]
    fn prompt_truncate_all_system_messages() {
        let msgs = vec![msg(Role::System, "s1"), msg(Role::System, "s2"), msg(Role::System, "s3")];
        let result = truncate_history(&msgs, 1);
        // All system messages are kept regardless of limit.
        assert_eq!(result.len(), 3);
    }

    #[test]
    fn prompt_truncate_limit_zero_keeps_only_system() {
        let msgs =
            vec![msg(Role::System, "sys"), msg(Role::User, "u1"), msg(Role::Assistant, "a1")];
        let result = truncate_history(&msgs, 0);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].role, Role::System);
    }

    #[test]
    fn prompt_truncate_interleaved_system_preserves_all() {
        let msgs = vec![
            msg(Role::System, "init"),
            msg(Role::User, "u1"),
            msg(Role::System, "mid-sys"),
            msg(Role::User, "u2"),
            msg(Role::Assistant, "a1"),
            msg(Role::System, "late-sys"),
            msg(Role::User, "u3"),
        ];
        // Four non-system messages (u1, u2, a1, u3); keeping 2 drops the two
        // oldest (u1, u2) and keeps a1 + u3. All system messages survive.
        let result = truncate_history(&msgs, 2);
        let system_count = result.iter().filter(|m| m.role == Role::System).count();
        assert_eq!(system_count, 3, "all three system messages must survive");
        let non_system: Vec<&str> =
            result.iter().filter(|m| m.role != Role::System).map(|m| m.content.as_str()).collect();
        assert_eq!(non_system, vec!["a1", "u3"]);
    }

    // ── build_llm_request ────────────────────────────────────────────

    #[test]
    fn prompt_build_empty_session() {
        let session = SessionState::default();
        let req = build_llm_request("test-model", &session, &[], "You are Bob.");
        assert_eq!(req.model, "test-model");
        // First message must be system.
        assert_eq!(req.messages[0].role, Role::System);
        assert!(req.messages[0].content.contains("You are Bob."));
        // No history messages besides system.
        assert_eq!(req.messages.len(), 1);
        assert!(req.tools.is_empty());
    }

    #[test]
    fn prompt_build_system_contains_action_schema() {
        let session = SessionState::default();
        let req = build_llm_request("m", &session, &[], "instructions");
        assert!(req.messages[0].content.contains("JSON"));
        assert!(req.messages[0].content.contains("tool_call"));
    }

    #[test]
    fn prompt_build_includes_tools() {
        let tools = vec![make_tool("t1")];
        let session = SessionState::default();
        let req = build_llm_request("m", &session, &tools, "inst");
        assert_eq!(req.tools.len(), 1);
        assert!(req.messages[0].content.contains("t1"));
    }

    #[test]
    fn prompt_build_message_ordering() {
        let session = SessionState {
            messages: vec![msg(Role::User, "hello"), msg(Role::Assistant, "hi")],
            total_usage: TokenUsage::default(),
        };
        let req = build_llm_request("m", &session, &[], "sys");
        assert_eq!(req.messages[0].role, Role::System);
        assert_eq!(req.messages[1].role, Role::User);
        assert_eq!(req.messages[2].role, Role::Assistant);
    }

    #[test]
    fn prompt_build_truncates_long_history() {
        let messages: Vec<Message> = (0..60).map(|i| msg(Role::User, &format!("m-{i}"))).collect();
        let session = SessionState { messages, total_usage: TokenUsage::default() };
        let req = build_llm_request("m", &session, &[], "sys");
        // 1 system + 50 truncated history = 51
        assert_eq!(req.messages.len(), 51);
        assert_eq!(req.messages[0].role, Role::System);
    }
}
399        // 1 system + 50 truncated history = 51
400        assert_eq!(req.messages.len(), 51);
401        assert_eq!(req.messages[0].role, Role::System);
402    }
403}