rustvani 0.4.0-dev.0

Voice AI framework for Rust — real-time speech pipelines with STT, LLM, TTS, and Dhara conversation flows
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
//! Shared conversation context.
//!
//! Owned by both aggregators via `Arc<Mutex<LLMContext>>`.
//! The LLM service reads it; the aggregators write to it.

use std::sync::{Arc, Mutex};

use serde::{Deserialize, Serialize};

use crate::adapters::schemas::{ToolChoice, ToolsSchema};

// ---------------------------------------------------------------------------
// ToolCall — a single function invocation requested by the model
// ---------------------------------------------------------------------------

/// A function call the model wants to execute.
///
/// Streamed as argument-string fragments during SSE; by the time this struct
/// is constructed, `arguments` is the fully accumulated JSON string.
///
/// The `id` is what ties a call to its answer: a [`Message::ToolResult`]
/// carries the same value in its `tool_call_id` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Unique call ID assigned by the model (e.g. `"call_abc123"`).
    pub id: String,
    /// Name of the function to invoke.
    pub function_name: String,
    /// Raw JSON string of the function arguments. Stored as text; this type
    /// does not parse or validate it.
    pub arguments: String,
}

// ---------------------------------------------------------------------------
// Message — type-safe conversation turn
// ---------------------------------------------------------------------------

/// A single turn in the conversation.
///
/// Each variant enforces what fields are valid for that role, unlike the
/// Python dict approach where any key can appear on any message.
///
/// The four variants are `System`, `User`, `Assistant` and `ToolResult`; only
/// `Assistant` can carry [`ToolCall`]s and only `ToolResult` can answer one.
#[derive(Debug, Clone)]
pub enum Message {
    /// System-level instruction. Typically the first message.
    System { content: String },

    /// User turn — transcribed speech or typed text.
    User { content: String },

    /// Assistant turn — may be text, tool calls, or both.
    Assistant {
        /// `None` when the model responds with only tool calls.
        content: Option<String>,
        /// `None` for plain text responses.
        tool_calls: Option<Vec<ToolCall>>,
    },

    /// Result of a tool/function call, sent back to the model.
    ToolResult {
        /// Matches the `id` of the `ToolCall` this responds to.
        tool_call_id: String,
        /// Serialized result (typically JSON).
        content: String,
    },
}

// ---------------------------------------------------------------------------
// LLMContext
// ---------------------------------------------------------------------------

/// Shared conversation context passed between aggregators and the LLM service.
///
/// ## Turn transaction (see `doc/turn-acid.md`)
///
/// Beyond the committed `messages`, the context carries a small **staging
/// buffer** used to make a tool round atomic. The LLM stages an assistant
/// `tool_calls` message and its `ToolResult`s with
/// [`stage_message`](Self::stage_message), then [`commit`](Self::commit)s them
/// into `messages` once the whole round is in. If the turn is interrupted
/// mid-round (barge-in), the in-flight future is dropped and
/// [`rollback`](Self::rollback) discards the staged messages — so `messages`
/// never retains an assistant `tool_calls` without its matching results (which
/// would make the next request malformed).
///
/// Commit happens at the **round boundary**, not turn end: the Dhara transition
/// hook runs between rounds and manipulates `messages` directly, so each round's
/// content must be in `messages` by the time the hook sees it. Plain user and
/// assistant text messages are committed directly (no staging) — they carry no
/// orphan risk.
#[derive(Debug, Clone)]
pub struct LLMContext {
    /// System prompt — prepended as `Message::System` in `to_api_messages()`.
    pub system_prompt: Option<String>,
    /// Conversation history (user turns, assistant turns, tool results).
    pub messages: Vec<Message>,
    /// Available tools for this context. `None` = no function calling.
    pub tools: Option<ToolsSchema>,
    /// How the model should pick tools. `None` = provider default (usually "auto").
    pub tool_choice: Option<ToolChoice>,
    /// Uncommitted messages for the in-flight tool round. Empty between rounds.
    /// Not part of `to_api_messages()` — committed into `messages` at the round
    /// boundary, or discarded by `rollback()` on interruption.
    /// Private: callers go through the `stage_*` methods and `staged_len()`.
    staged: Vec<Message>,
    /// Monotonic turn identity, bumped by `begin_turn()`. Seeds turn-level
    /// isolation; full epoch fencing across the agent bus is future work.
    /// Private: read it through `epoch()`.
    epoch: u64,
}

impl LLMContext {
    /// Create an empty context without tools.
    ///
    /// The history and the staging buffer start empty, no tools or tool choice
    /// are configured, and the turn epoch starts at 0.
    pub fn new(system_prompt: Option<String>) -> Self {
        Self {
            // Configuration.
            system_prompt,
            tools: None,
            tool_choice: None,
            // Conversation state.
            messages: Vec::new(),
            staged: Vec::new(),
            epoch: 0,
        }
    }

    /// Create a context with tools configured.
    pub fn with_tools(
        system_prompt: Option<String>,
        tools: ToolsSchema,
        tool_choice: Option<ToolChoice>,
    ) -> Self {
        Self {
            system_prompt,
            messages: Vec::new(),
            tools: Some(tools),
            tool_choice,
            staged: Vec::new(),
            epoch: 0,
        }
    }

    // ---- Convenience push methods ----

    /// Append any message variant to the committed history.
    ///
    /// The message goes straight into `messages`; it bypasses the staging
    /// buffer and is not checked for tool-call/result pairing.
    pub fn push_message(&mut self, msg: Message) {
        self.messages.push(msg);
    }

    /// Commit a user turn with the given text directly to the history.
    pub fn add_user_message(&mut self, content: impl Into<String>) {
        let content = content.into();
        self.push_message(Message::User { content });
    }

    /// Commit a text-only assistant turn (its `tool_calls` is `None`) directly
    /// to the history.
    pub fn add_assistant_message(&mut self, content: impl Into<String>) {
        let text = content.into();
        self.push_message(Message::Assistant {
            content: Some(text),
            tool_calls: None,
        });
    }

    /// Commit an assistant turn carrying tool calls directly to the history.
    ///
    /// `content` is the optional text that accompanies the calls. This writes
    /// to `messages` immediately; use
    /// [`stage_assistant_tool_calls`](Self::stage_assistant_tool_calls) when the
    /// round should only become visible once its results are in.
    pub fn add_assistant_tool_calls(
        &mut self,
        content: Option<String>,
        tool_calls: Vec<ToolCall>,
    ) {
        let turn = Message::Assistant {
            content,
            tool_calls: Some(tool_calls),
        };
        self.push_message(turn);
    }

    /// Commit a tool result directly to the history.
    ///
    /// `tool_call_id` should be the `id` of the [`ToolCall`] being answered;
    /// this method stores it as given and does not verify the match.
    pub fn add_tool_result(
        &mut self,
        tool_call_id: impl Into<String>,
        content: impl Into<String>,
    ) {
        let (tool_call_id, content) = (tool_call_id.into(), content.into());
        self.push_message(Message::ToolResult {
            tool_call_id,
            content,
        });
    }

    // ---- Turn transaction (see doc/turn-acid.md) ----

    /// Start a new turn and return its epoch.
    ///
    /// Any messages still sitting in the staging buffer are thrown away first —
    /// they belong to an earlier round that never committed. Unlike
    /// [`rollback`](Self::rollback) this discard is not logged. The
    /// [`epoch`](Self::epoch) is then advanced by one, wrapping around on
    /// overflow.
    pub fn begin_turn(&mut self) -> u64 {
        // Leftovers from an interrupted round must not leak into the new turn.
        self.staged.clear();

        let next_epoch = self.epoch.wrapping_add(1);
        self.epoch = next_epoch;
        next_epoch
    }

    /// Current turn epoch.
    ///
    /// Starts at 0 for a freshly constructed context and is advanced by
    /// [`begin_turn`](Self::begin_turn).
    pub fn epoch(&self) -> u64 {
        self.epoch
    }

    /// Number of staged (uncommitted) messages. Exposed for tests/diagnostics.
    ///
    /// Drops back to 0 after [`commit`](Self::commit),
    /// [`rollback`](Self::rollback) or [`begin_turn`](Self::begin_turn).
    pub fn staged_len(&self) -> usize {
        self.staged.len()
    }

    /// Stage a message for the in-flight round. Not visible to
    /// `to_api_messages()` until [`commit`](Self::commit)ted.
    ///
    /// Accepts any variant; the message is appended to the staging buffer as-is.
    pub fn stage_message(&mut self, msg: Message) {
        self.staged.push(msg);
    }

    /// Stage an assistant turn carrying tool calls for the in-flight round.
    ///
    /// `content` is the optional text that accompanies the calls. Nothing
    /// reaches `messages` until [`commit`](Self::commit) is called.
    pub fn stage_assistant_tool_calls(
        &mut self,
        content: Option<String>,
        tool_calls: Vec<ToolCall>,
    ) {
        let turn = Message::Assistant {
            content,
            tool_calls: Some(tool_calls),
        };
        self.stage_message(turn);
    }

    /// Stage the result of a tool call for the in-flight round.
    ///
    /// `tool_call_id` should be the `id` of a staged [`ToolCall`]; a result
    /// whose id matches no surviving call is dropped at
    /// [`commit`](Self::commit).
    pub fn stage_tool_result(
        &mut self,
        tool_call_id: impl Into<String>,
        content: impl Into<String>,
    ) {
        let (tool_call_id, content) = (tool_call_id.into(), content.into());
        self.stage_message(Message::ToolResult {
            tool_call_id,
            content,
        });
    }

    /// Move the staged round into `messages` and return how many messages were
    /// actually committed.
    ///
    /// Before the move the batch is cleaned up: an assistant `tool_calls`
    /// message is discarded unless every one of its call ids has a staged
    /// `ToolResult`, and results left without a surviving call are discarded
    /// too (the API rejects an assistant `tool_calls` without matching
    /// results). The staging buffer is empty afterwards. Returns 0 when nothing
    /// was staged or when the cleanup removed everything.
    pub fn commit(&mut self) -> usize {
        if self.staged.is_empty() {
            return 0;
        }

        // Take ownership of the batch so the staging buffer is left empty.
        let mut round = std::mem::take(&mut self.staged);
        Self::repair_orphan_tool_calls(&mut round);

        let committed = round.len();
        self.messages.extend(round);
        committed
    }

    /// Throw away every staged message; committed `messages` are left alone.
    ///
    /// Meant for interruptions, so that an aborted tool round leaves no orphan
    /// behind. Calling it with nothing staged is a no-op (and logs nothing), so
    /// repeated calls are harmless.
    pub fn rollback(&mut self) {
        let pending = self.staged.len();
        if pending == 0 {
            return;
        }
        log::debug!("LLMContext: rolling back {} staged message(s)", pending);
        self.staged.clear();
    }

    /// Strip incomplete tool rounds out of a staged batch, in place.
    ///
    /// An assistant message with `tool_calls` survives only if every call id in
    /// it is answered by some `ToolResult` in the same batch; each one that is
    /// dropped emits a warning. A `ToolResult` survives only if its id belongs
    /// to a surviving assistant message. All other messages are kept, and the
    /// relative order of survivors is unchanged.
    fn repair_orphan_tool_calls(staged: &mut Vec<Message>) {
        use std::collections::HashSet;

        // Decide every message's fate under a shared borrow (so ids can be
        // borrowed rather than cloned), then apply the verdicts in one sweep.
        let verdicts: Vec<bool> = {
            let result_ids: HashSet<&str> = staged
                .iter()
                .filter_map(|msg| match msg {
                    Message::ToolResult { tool_call_id, .. } => Some(tool_call_id.as_str()),
                    _ => None,
                })
                .collect();

            // Call ids owned by assistant messages that are fully answered.
            let mut surviving_calls: HashSet<&str> = HashSet::new();
            let mut first_pass: Vec<bool> = Vec::with_capacity(staged.len());
            for msg in staged.iter() {
                let keep = match msg {
                    Message::Assistant { tool_calls: Some(calls), .. } => {
                        let complete = calls
                            .iter()
                            .all(|call| result_ids.contains(call.id.as_str()));
                        if complete {
                            surviving_calls.extend(calls.iter().map(|call| call.id.as_str()));
                        } else {
                            log::warn!(
                                "LLMContext: dropping orphaned assistant tool_calls at commit \
                                 (unanswered tool call)"
                            );
                        }
                        complete
                    }
                    _ => true,
                };
                first_pass.push(keep);
            }

            // Results additionally need a surviving call to answer.
            staged
                .iter()
                .zip(first_pass)
                .map(|(msg, keep)| match msg {
                    Message::ToolResult { tool_call_id, .. } => {
                        keep && surviving_calls.contains(tool_call_id.as_str())
                    }
                    _ => keep,
                })
                .collect()
        };

        // `retain` visits elements once each, in order, so the verdicts line up.
        let mut verdicts = verdicts.into_iter();
        staged.retain(|_| verdicts.next().unwrap_or(true));
    }

    /// Assemble the message list to send with an API request.
    ///
    /// When a system prompt is set it comes first as a `Message::System`,
    /// followed by a clone of every committed message in order. Staged messages
    /// are not included. The adapter then converts these `Message` variants
    /// into the provider's wire format.
    pub fn to_api_messages(&self) -> Vec<Message> {
        let system = self.system_prompt.as_ref().map(|prompt| Message::System {
            content: prompt.clone(),
        });
        system
            .into_iter()
            .chain(self.messages.iter().cloned())
            .collect()
    }

    /// Rough token estimate for the system prompt plus the committed history,
    /// at roughly four bytes per token.
    ///
    /// Counted: the UTF-8 byte length of the system prompt and of every
    /// message's text content, plus, for each tool call, the byte lengths of
    /// its function name and arguments and a flat 20-byte allowance. Not
    /// counted: staged messages, the `tool_call_id` of tool results, and the
    /// `tools` schema.
    pub fn estimate_tokens(&self) -> usize {
        let prompt_bytes = self.system_prompt.as_ref().map_or(0, String::len);

        let history_bytes: usize = self
            .messages
            .iter()
            .map(|msg| match msg {
                Message::System { content }
                | Message::User { content }
                | Message::ToolResult { content, .. } => content.len(),
                Message::Assistant { content, tool_calls } => {
                    let text_bytes = content.as_ref().map_or(0, String::len);
                    let call_bytes: usize = tool_calls
                        .iter()
                        .flatten()
                        .map(|call| call.function_name.len() + call.arguments.len() + 20)
                        .sum();
                    text_bytes + call_bytes
                }
            })
            .sum();

        (prompt_bytes + history_bytes) / 4
    }

    /// Drop oldest conversation groups until the estimated token count fits
    /// within `context_window_tokens * 0.8` (reserves headroom for the reply).
    ///
    /// A "group" is everything from one User message up to (but not including)
    /// the next User message, so Assistant tool-call + ToolResult pairs are
    /// never orphaned. Stops if no safe drop point remains.
    pub fn trim_to_context_budget(&mut self, context_window_tokens: usize) {
        let budget = (context_window_tokens as f64 * 0.8) as usize;
        loop {
            if self.estimate_tokens() <= budget {
                break;
            }
            // Find the first User message that has another User message after it.
            let first_user = self
                .messages
                .iter()
                .position(|m| matches!(m, Message::User { .. }));
            let next_user = first_user.and_then(|i| {
                self.messages[i + 1..]
                    .iter()
                    .position(|m| matches!(m, Message::User { .. }))
                    .map(|j| i + 1 + j)
            });
            match (first_user, next_user) {
                (Some(start), Some(end)) => {
                    let dropped = end - start;
                    self.messages.drain(start..end);
                    log::warn!(
                        "LLMContext: trimmed {} messages to fit {}-token budget",
                        dropped,
                        context_window_tokens
                    );
                }
                _ => {
                    log::warn!(
                        "LLMContext: context near limit ({} estimated tokens) but cannot safely trim further",
                        self.estimate_tokens()
                    );
                    break;
                }
            }
        }
    }
}

/// Build a tool-less [`LLMContext`] and wrap it in `Arc<Mutex<_>>` so it can be
/// shared across the pipeline.
pub fn shared_context(system_prompt: Option<String>) -> Arc<Mutex<LLMContext>> {
    let context = LLMContext::new(system_prompt);
    Arc::new(Mutex::new(context))
}

/// Build an [`LLMContext`] with tools configured and wrap it in
/// `Arc<Mutex<_>>` so it can be shared across the pipeline.
pub fn shared_context_with_tools(
    system_prompt: Option<String>,
    tools: ToolsSchema,
    tool_choice: Option<ToolChoice>,
) -> Arc<Mutex<LLMContext>> {
    let context = LLMContext::with_tools(system_prompt, tools, tool_choice);
    Arc::new(Mutex::new(context))
}

// ---------------------------------------------------------------------------
// Tests — turn transaction (see doc/turn-acid.md)
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    /// Build a tool call whose arguments are an empty JSON object.
    fn call(id: &str, function_name: &str) -> ToolCall {
        ToolCall {
            id: id.to_owned(),
            function_name: function_name.to_owned(),
            arguments: String::from("{}"),
        }
    }

    /// Text of every assistant message that has content and no tool calls.
    fn plain_assistant_texts(history: &[Message]) -> Vec<&str> {
        let mut texts = Vec::new();
        for message in history {
            if let Message::Assistant { content: Some(text), tool_calls: None } = message {
                texts.push(text.as_str());
            }
        }
        texts
    }

    #[test]
    fn staged_is_invisible_until_commit() {
        let mut ctx = LLMContext::new(None);
        ctx.add_user_message("hello");
        ctx.stage_assistant_tool_calls(None, vec![call("call_1", "lookup")]);
        ctx.stage_tool_result("call_1", "ok");

        // Both messages are staged, but only the user turn is visible.
        assert_eq!(ctx.staged_len(), 2);
        let visible = ctx.to_api_messages();
        assert_eq!(visible.len(), 1);
    }

    #[test]
    fn commit_splices_full_round() {
        let mut ctx = LLMContext::new(None);
        ctx.add_user_message("status of 4471?");
        ctx.stage_assistant_tool_calls(None, vec![call("call_1", "lookup")]);
        ctx.stage_tool_result("call_1", "shipped");

        let committed = ctx.commit();
        assert_eq!(committed, 2);
        assert_eq!(ctx.staged_len(), 0);

        // History is now: user, assistant(tool_calls), tool_result.
        assert_eq!(ctx.messages.len(), 3);
        assert!(matches!(
            ctx.messages.as_slice(),
            [
                _,
                Message::Assistant { tool_calls: Some(_), .. },
                Message::ToolResult { .. },
            ]
        ));
    }

    #[test]
    fn rollback_discards_orphaned_round() {
        // The interruption lands after the tool_calls were staged but before
        // any result arrived.
        let mut ctx = LLMContext::new(None);
        ctx.add_user_message("status of 4471?");
        ctx.stage_assistant_tool_calls(None, vec![call("call_1", "lookup")]);

        ctx.rollback();

        assert_eq!(ctx.staged_len(), 0);
        // Only the committed user turn is left; no orphaned tool_calls.
        assert_eq!(ctx.messages.len(), 1);
        assert!(matches!(ctx.messages.as_slice(), [Message::User { .. }]));
    }

    #[test]
    fn commit_drops_orphan_tool_calls_for_consistency() {
        // Two calls, one answer: committing must not splice in an assistant
        // tool_calls message that lacks a result for every call.
        let mut ctx = LLMContext::new(None);
        let calls = vec![call("call_1", "a"), call("call_2", "b")];
        ctx.stage_assistant_tool_calls(None, calls);
        ctx.stage_tool_result("call_1", "done"); // call_2 stays unanswered

        let committed = ctx.commit();
        assert_eq!(committed, 0, "orphaned round must be dropped entirely");
        assert!(ctx.messages.is_empty());
    }

    #[test]
    fn commit_keeps_plain_text_assistant() {
        let mut ctx = LLMContext::new(None);
        let greeting = Message::Assistant {
            content: Some(String::from("hi there")),
            tool_calls: None,
        };
        ctx.stage_message(greeting);

        assert_eq!(ctx.commit(), 1);
        assert_eq!(plain_assistant_texts(&ctx.messages), vec!["hi there"]);
    }

    #[test]
    fn begin_turn_bumps_epoch_and_clears_stale_staged() {
        let mut ctx = LLMContext::new(None);
        let before = ctx.epoch();
        // Leftover from a round that never committed.
        ctx.stage_assistant_tool_calls(None, vec![call("call_1", "lookup")]);

        let after = ctx.begin_turn();

        assert_eq!(after, before + 1);
        assert_eq!(ctx.staged_len(), 0, "begin_turn discards a prior interrupted round");
    }
}