1pub mod hooks;
2pub mod streaming;
3
4use super::{
5 Agent,
6 completion::{DynamicContextStore, build_completion_request},
7};
8use crate::{
9 OneOrMany,
10 completion::{CompletionModel, Document, Message, PromptError, Usage},
11 json_utils,
12 memory::ConversationMemory,
13 message::{AssistantContent, ToolChoice, ToolResultContent, UserContent},
14 tool::server::ToolServerHandle,
15 wasm_compat::{WasmBoxedFuture, WasmCompatSend},
16};
17use futures::{StreamExt, stream};
18use hooks::{HookAction, PromptHook, ToolCallHookAction};
19use serde::{Deserialize, Serialize};
20use std::{
21 future::IntoFuture,
22 marker::PhantomData,
23 sync::{
24 Arc,
25 atomic::{AtomicU64, Ordering},
26 },
27};
28use tracing::info_span;
29use tracing::{Instrument, span::Id};
30
/// Marker trait for the type-state of a [`PromptRequest`].
pub trait PromptType {}
/// Type-state: awaiting the request yields only the output `String`.
pub struct Standard;
/// Type-state: awaiting the request yields a full [`PromptResponse`].
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
37
/// A pending agent prompt, configured via the builder methods below.
/// The `S` type-state selects what awaiting the request returns:
/// just the output text (`Standard`) or a full [`PromptResponse`]
/// (`Extended`).
pub struct PromptRequest<S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    // The user prompt that seeds the loop.
    prompt: Message,
    // Explicit caller-supplied history; when set, conversation memory is
    // neither loaded nor appended.
    chat_history: Option<Vec<Message>>,
    // Turn budget before the loop errors with `MaxTurnsError`.
    max_turns: usize,

    // Completion model used for every turn.
    model: Arc<M>,
    // Agent name for tracing; `None` falls back to `UNKNOWN_AGENT_NAME`.
    agent_name: Option<String>,
    // System instructions forwarded with each completion request.
    preamble: Option<String>,
    // Documents always attached to the request context.
    static_context: Vec<Document>,
    // Sampling temperature override, if any.
    temperature: Option<f64>,
    // Output-token cap, if any.
    max_tokens: Option<u64>,
    // Provider-specific extra parameters.
    additional_params: Option<serde_json::Value>,
    // Handle used to resolve and execute tool calls.
    tool_server_handle: ToolServerHandle,
    // Dynamic (per-request) context sources.
    dynamic_context: DynamicContextStore,
    // Tool-choice policy forwarded to the provider.
    tool_choice: Option<ToolChoice>,

    // Type-state marker (`Standard` or `Extended`).
    state: PhantomData<S>,
    // Optional hook observing/controlling the prompt lifecycle.
    hook: Option<P>,
    // Max tool calls executed concurrently within one turn.
    concurrency: usize,
    // JSON schema requested for structured output, if any.
    output_schema: Option<schemars::Schema>,
    // Conversation memory backend; only used together with
    // `conversation_id`.
    memory: Option<Arc<dyn ConversationMemory>>,
    // Conversation to load from / append to in `memory`.
    conversation_id: Option<String>,
}
94
impl<M, P> PromptRequest<Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Creates a request seeded from the agent's configuration: model,
    /// preamble, contexts, tool handle, hook, memory, and defaults.
    /// Tool concurrency starts at 1 (sequential tool execution).
    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
        PromptRequest {
            prompt: prompt.into(),
            chat_history: None,
            // Falls back to 0 when the agent has no default turn budget.
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            state: PhantomData,
            hook: agent.hook.clone(),
            concurrency: 1,
            output_schema: agent.output_schema.clone(),
            memory: agent.memory.clone(),
            conversation_id: agent.default_conversation_id.clone(),
        }
    }
}
125
126impl<S, M, P> PromptRequest<S, M, P>
127where
128 S: PromptType,
129 M: CompletionModel,
130 P: PromptHook<M>,
131{
132 pub fn extended_details(self) -> PromptRequest<Extended, M, P> {
139 PromptRequest {
140 prompt: self.prompt,
141 chat_history: self.chat_history,
142 max_turns: self.max_turns,
143 model: self.model,
144 agent_name: self.agent_name,
145 preamble: self.preamble,
146 static_context: self.static_context,
147 temperature: self.temperature,
148 max_tokens: self.max_tokens,
149 additional_params: self.additional_params,
150 tool_server_handle: self.tool_server_handle,
151 dynamic_context: self.dynamic_context,
152 tool_choice: self.tool_choice,
153 state: PhantomData,
154 hook: self.hook,
155 concurrency: self.concurrency,
156 output_schema: self.output_schema,
157 memory: self.memory,
158 conversation_id: self.conversation_id,
159 }
160 }
161
162 pub fn max_turns(mut self, depth: usize) -> Self {
165 self.max_turns = depth;
166 self
167 }
168
169 pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
172 self.concurrency = concurrency;
173 self
174 }
175
176 pub fn with_history<I, T>(mut self, history: I) -> Self
178 where
179 I: IntoIterator<Item = T>,
180 T: Into<Message>,
181 {
182 self.chat_history = Some(history.into_iter().map(Into::into).collect());
183 self
184 }
185
186 pub fn conversation(mut self, id: impl Into<String>) -> Self {
191 self.conversation_id = Some(id.into());
192 self
193 }
194
195 pub fn without_memory(mut self) -> Self {
199 self.memory = None;
200 self.conversation_id = None;
201 self
202 }
203
204 pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<S, M, P2>
207 where
208 P2: PromptHook<M>,
209 {
210 PromptRequest {
211 prompt: self.prompt,
212 chat_history: self.chat_history,
213 max_turns: self.max_turns,
214 model: self.model,
215 agent_name: self.agent_name,
216 preamble: self.preamble,
217 static_context: self.static_context,
218 temperature: self.temperature,
219 max_tokens: self.max_tokens,
220 additional_params: self.additional_params,
221 tool_server_handle: self.tool_server_handle,
222 dynamic_context: self.dynamic_context,
223 tool_choice: self.tool_choice,
224 state: PhantomData,
225 hook: Some(hook),
226 concurrency: self.concurrency,
227 output_schema: self.output_schema,
228 memory: self.memory,
229 conversation_id: self.conversation_id,
230 }
231 }
232}
233
impl<M, P> IntoFuture for PromptRequest<Standard, M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Awaiting a `Standard` request resolves to just the output text.
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
249
impl<M, P> IntoFuture for PromptRequest<Extended, M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Awaiting an `Extended` request resolves to a full [`PromptResponse`].
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
262
263impl<M, P> PromptRequest<Standard, M, P>
264where
265 M: CompletionModel,
266 P: PromptHook<M>,
267{
268 async fn send(self) -> Result<String, PromptError> {
269 self.extended_details().send().await.map(|resp| resp.output)
270 }
271}
272
/// Full result of an `Extended` prompt run.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct PromptResponse {
    /// Final merged text output of the terminal model turn.
    pub output: String,
    /// Token usage accumulated across every turn of the run.
    pub usage: Usage,
    /// Messages generated during this run (prompt, assistant turns,
    /// tool results), when available.
    pub messages: Option<Vec<Message>>,
}
280
281impl std::fmt::Display for PromptResponse {
282 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283 self.output.fmt(f)
284 }
285}
286
287impl PromptResponse {
288 pub fn new(output: impl Into<String>, usage: Usage) -> Self {
289 Self {
290 output: output.into(),
291 usage,
292 messages: None,
293 }
294 }
295
296 pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
297 self.messages = Some(messages);
298 self
299 }
300}
301
/// Result of a typed (structured-output) prompt run: the deserialized value
/// plus the accumulated token usage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TypedPromptResponse<T> {
    /// The model output deserialized into `T`.
    pub output: T,
    /// Token usage accumulated across the run.
    pub usage: Usage,
}
307
impl<T> TypedPromptResponse<T> {
    /// Bundles a parsed output value with its token usage.
    pub fn new(output: T, usage: Usage) -> Self {
        Self { output, usage }
    }
}
313
/// Fallback agent name used in tracing spans when none was configured.
const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
315
316fn build_history_for_request(
318 chat_history: Option<&[Message]>,
319 new_messages: &[Message],
320) -> Vec<Message> {
321 let input = chat_history.unwrap_or(&[]);
322 input.iter().chain(new_messages.iter()).cloned().collect()
323}
324
325fn build_full_history(
327 chat_history: Option<&[Message]>,
328 new_messages: Vec<Message>,
329) -> Vec<Message> {
330 let input = chat_history.unwrap_or(&[]);
331 input.iter().cloned().chain(new_messages).collect()
332}
333
334fn is_empty_assistant_turn(choice: &OneOrMany<AssistantContent>) -> bool {
335 choice.len() == 1
336 && matches!(
337 choice.first(),
338 AssistantContent::Text(text) if text.text.is_empty()
339 )
340}
341
impl<M, P> PromptRequest<Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Agent name for tracing, falling back to [`UNKNOWN_AGENT_NAME`].
    fn agent_name(&self) -> &str {
        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }

    /// Drives the multi-turn prompt loop: requests a completion, executes
    /// any tool calls it contains (up to `self.concurrency` at a time),
    /// feeds the tool results back as a user message, and repeats until the
    /// model answers with text only or the turn budget runs out.
    ///
    /// # Errors
    /// * [`PromptError::MaxTurnsError`] when the turn budget is exceeded.
    /// * A "prompt cancelled" error when a hook terminates the run.
    /// * Any error bubbled up from the model, tool server, or memory load.
    async fn send(self) -> Result<PromptResponse, PromptError> {
        // Reuse an already-active span when one exists; otherwise open an
        // `invoke_agent` span whose usage fields are recorded on completion.
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent_name(),
                gen_ai.system_instructions = self.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent_name_for_span = self.agent_name.clone();
        // Resolve the starting history: explicit `with_history` wins and
        // bypasses memory entirely; memory is consulted only when both a
        // store and a conversation id are configured. `memory_handle` is
        // kept so the finished turn can be appended back on success.
        let (chat_history, memory_handle) = match self.chat_history {
            Some(history) => (Some(history), None),
            None => match (self.memory, self.conversation_id) {
                (Some(memory), Some(id)) => {
                    let loaded = memory.load(&id).await?;
                    (Some(loaded), Some((memory, id)))
                }
                _ => (None, None),
            },
        };
        // Messages produced by this run; the last entry is always the next
        // message to send to the model.
        let mut new_messages: Vec<Message> = vec![self.prompt.clone()];

        let mut current_max_turns = 0;
        let mut usage = Usage::new();
        // Id of the previous chat/tool span, used to link successive spans
        // with `follows_from`. NOTE(review): concurrent tool tasks all write
        // this atomic; last writer wins — confirm that ordering is intended.
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        let last_prompt = loop {
            // Invariant: `new_messages` is non-empty (seeded with the prompt
            // and only ever pushed to); the tail is the pending message.
            let Some((prompt_ref, history_for_current_turn)) = new_messages.split_last() else {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    "prompt loop lost its pending prompt",
                ));
            };
            let prompt = prompt_ref.clone();

            // NOTE(review): the budget allows `max_turns + 1` iterations —
            // presumably counting the initial prompt turn separately from
            // follow-ups; confirm the off-by-one is intentional.
            if current_max_turns > self.max_turns + 1 {
                break prompt;
            }

            current_max_turns += 1;

            if self.max_turns > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_turns,
                    self.max_turns
                );
            }

            let history_for_hook =
                build_history_for_request(chat_history.as_deref(), history_for_current_turn);

            // Give the hook a chance to veto the completion call entirely.
            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_call(&prompt, &history_for_hook).await
            {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    reason,
                ));
            }

            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.agent.name = agent_name_for_span.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                gen_ai.system_instructions = self.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                gen_ai.usage.reasoning_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

            // Link this turn's span to the previous turn's span, then
            // publish our own id for the next turn to link against.
            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            let history_for_request =
                build_history_for_request(chat_history.as_deref(), history_for_current_turn);

            let resp = build_completion_request(
                &self.model,
                prompt.clone(),
                &history_for_request,
                self.preamble.as_deref(),
                &self.static_context,
                self.temperature,
                self.max_tokens,
                self.additional_params.as_ref(),
                self.tool_choice.as_ref(),
                &self.tool_server_handle,
                &self.dynamic_context,
                self.output_schema.as_ref(),
            )
            .await?
            .send()
            .instrument(chat_span.clone())
            .await?;

            // Accumulate usage across every turn of the run.
            usage += resp.usage;

            // Hook may also terminate after seeing the raw response.
            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_response(&prompt, &resp).await
            {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    reason,
                ));
            }

            // Split the assistant turn into tool calls vs. plain text.
            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            // Don't record an assistant turn that is a single empty text
            // chunk — it would pollute the history.
            if !is_empty_assistant_turn(&resp.choice) {
                new_messages.push(Message::Assistant {
                    id: resp.message_id.clone(),
                    content: resp.choice.clone(),
                });
            }

            // Terminal case: no tool calls means the model answered with
            // text; merge it, record telemetry, persist to memory, return.
            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_turns > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_turns, self.max_turns);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);
                agent_span.record(
                    "gen_ai.usage.cache_read.input_tokens",
                    usage.cached_input_tokens,
                );
                agent_span.record(
                    "gen_ai.usage.cache_creation.input_tokens",
                    usage.cache_creation_input_tokens,
                );
                agent_span.record("gen_ai.usage.reasoning_tokens", usage.reasoning_tokens);

                // Memory append is best-effort: a failure is logged but does
                // not discard the successful model response.
                if let Some((memory, id)) = memory_handle.as_ref()
                    && let Err(err) = memory.append(id, new_messages.clone()).await
                {
                    tracing::warn!(
                        error = %err,
                        conversation_id = %id,
                        "conversation memory append failed; returning model response anyway"
                    );
                }

                return Ok(PromptResponse::new(merged_texts, usage).with_messages(new_messages));
            }

            let hook = self.hook.clone();
            let tool_server_handle = self.tool_server_handle.clone();

            // Pre-built transcript handed to cancellation errors raised from
            // inside the concurrent tool tasks below.
            let full_history_for_errors =
                build_full_history(chat_history.as_deref(), new_messages.clone());

            // Execute all tool calls from this turn, at most
            // `self.concurrency` in flight at once; any task error aborts
            // the whole turn via the final `collect::<Result<...>>`.
            let tool_calls: Vec<AssistantContent> = tool_calls.into_iter().cloned().collect();
            let tool_content = stream::iter(tool_calls)
                .map(|choice| {
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();
                    let tool_server_handle = tool_server_handle.clone();

                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    // Same follows_from linking as the chat span above.
                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    let cloned_history_for_error = full_history_for_errors.clone();

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args =
                                json_utils::value_to_json_string(&tool_call.function.arguments);
                            // Internal correlation id handed to the hook for
                            // pairing on_tool_call with on_tool_result.
                            let internal_call_id = nanoid::nanoid!();
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            if let Some(hook) = hook1 {
                                let action = hook
                                    .on_tool_call(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                    )
                                    .await;

                                if let ToolCallHookAction::Terminate { reason } = action {
                                    return Err(PromptError::prompt_cancelled(
                                        cloned_history_for_error,
                                        reason,
                                    ));
                                }

                                // Skip: the tool is not run; the skip reason
                                // is returned to the model as the result.
                                if let ToolCallHookAction::Skip { reason } = action {
                                    tracing::info!(
                                        tool_name = tool_name,
                                        reason = reason,
                                        "Tool call rejected"
                                    );
                                    if let Some(call_id) = tool_call.call_id.clone() {
                                        return Ok(UserContent::tool_result_with_call_id(
                                            tool_call.id.clone(),
                                            call_id,
                                            OneOrMany::one(reason.into()),
                                        ));
                                    } else {
                                        return Ok(UserContent::tool_result(
                                            tool_call.id.clone(),
                                            OneOrMany::one(reason.into()),
                                        ));
                                    }
                                }
                            }
                            // Tool failures are not fatal: the error text is
                            // fed back to the model as the tool output.
                            let output = match tool_server_handle.call_tool(tool_name, &args).await
                            {
                                Ok(res) => res,
                                Err(e) => {
                                    tracing::warn!("Error while executing tool: {e}");
                                    e.to_string()
                                }
                            };
                            if let Some(hook) = hook2
                                && let HookAction::Terminate { reason } = hook
                                    .on_tool_result(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                        // `output` is already a String; this
                                        // just clones it for the hook.
                                        &output.to_string(),
                                    )
                                    .await
                            {
                                return Err(PromptError::prompt_cancelled(
                                    cloned_history_for_error,
                                    reason,
                                ));
                            }

                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    ToolResultContent::from_tool_output(output),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    ToolResultContent::from_tool_output(output),
                                ))
                            }
                        } else {
                            // Unreachable by construction: the partition
                            // above only feeds ToolCall variants in here.
                            Err(PromptError::prompt_cancelled(
                                Vec::new(),
                                "tool execution received non-tool assistant content",
                            ))
                        }
                    }
                    .instrument(tool_span)
                })
                .buffer_unordered(self.concurrency)
                .collect::<Vec<Result<UserContent, PromptError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()?;

            let Some(content) = OneOrMany::from_iter_optional(tool_content) else {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    "tool execution produced no tool results",
                ));
            };

            // Tool results become the next pending user message; loop again.
            new_messages.push(Message::User { content });
        };

        Err(PromptError::MaxTurnsError {
            max_turns: self.max_turns,
            chat_history: build_full_history(chat_history.as_deref(), new_messages).into(),
            prompt: last_prompt.into(),
        })
    }
}
719
720use crate::completion::StructuredOutputError;
725use schemars::{JsonSchema, schema_for};
726use serde::de::DeserializeOwned;
727
/// A prompt request whose output is deserialized into `T` via the
/// provider's structured-output support; wraps a [`PromptRequest`] and
/// mirrors its builder API.
pub struct TypedPromptRequest<T, S, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Underlying untyped request (with `output_schema` derived from `T`).
    inner: PromptRequest<S, M, P>,
    /// Marker for the target output type `T`.
    _phantom: std::marker::PhantomData<T>,
}
754
755impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
756where
757 T: JsonSchema + DeserializeOwned + WasmCompatSend,
758 M: CompletionModel,
759 P: PromptHook<M>,
760{
761 pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
765 let mut inner = PromptRequest::from_agent(agent, prompt);
766 inner.output_schema = Some(schema_for!(T));
768 Self {
769 inner,
770 _phantom: std::marker::PhantomData,
771 }
772 }
773}
774
775impl<T, S, M, P> TypedPromptRequest<T, S, M, P>
776where
777 T: JsonSchema + DeserializeOwned + WasmCompatSend,
778 S: PromptType,
779 M: CompletionModel,
780 P: PromptHook<M>,
781{
782 pub fn extended_details(self) -> TypedPromptRequest<T, Extended, M, P> {
788 TypedPromptRequest {
789 inner: self.inner.extended_details(),
790 _phantom: std::marker::PhantomData,
791 }
792 }
793
794 pub fn max_turns(mut self, depth: usize) -> Self {
800 self.inner = self.inner.max_turns(depth);
801 self
802 }
803
804 pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
808 self.inner = self.inner.with_tool_concurrency(concurrency);
809 self
810 }
811
812 pub fn with_history<I, H>(mut self, history: I) -> Self
814 where
815 I: IntoIterator<Item = H>,
816 H: Into<Message>,
817 {
818 self.inner = self.inner.with_history(history);
819 self
820 }
821
822 pub fn conversation(mut self, id: impl Into<String>) -> Self {
827 self.inner = self.inner.conversation(id);
828 self
829 }
830
831 pub fn without_memory(mut self) -> Self {
835 self.inner = self.inner.without_memory();
836 self
837 }
838
839 pub fn with_hook<P2>(self, hook: P2) -> TypedPromptRequest<T, S, M, P2>
843 where
844 P2: PromptHook<M>,
845 {
846 TypedPromptRequest {
847 inner: self.inner.with_hook(hook),
848 _phantom: std::marker::PhantomData,
849 }
850 }
851}
852
853impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
854where
855 T: JsonSchema + DeserializeOwned + WasmCompatSend,
856 M: CompletionModel,
857 P: PromptHook<M>,
858{
859 async fn send(self) -> Result<T, StructuredOutputError> {
861 let response = self.inner.send().await.map_err(Box::new)?;
862
863 if response.is_empty() {
864 return Err(StructuredOutputError::EmptyResponse);
865 }
866
867 let parsed: T = serde_json::from_str(&response)?;
868 Ok(parsed)
869 }
870}
871
872impl<T, M, P> TypedPromptRequest<T, Extended, M, P>
873where
874 T: JsonSchema + DeserializeOwned + WasmCompatSend,
875 M: CompletionModel,
876 P: PromptHook<M>,
877{
878 async fn send(self) -> Result<TypedPromptResponse<T>, StructuredOutputError> {
880 let response = self.inner.send().await.map_err(Box::new)?;
881
882 if response.output.is_empty() {
883 return Err(StructuredOutputError::EmptyResponse);
884 }
885
886 let parsed: T = serde_json::from_str(&response.output)?;
887 Ok(TypedPromptResponse::new(parsed, response.usage))
888 }
889}
890
impl<T, M, P> IntoFuture for TypedPromptRequest<T, Standard, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Awaiting a `Standard` typed request resolves to the parsed `T`.
    type Output = Result<T, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
904
impl<T, M, P> IntoFuture for TypedPromptRequest<T, Extended, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    /// Awaiting an `Extended` typed request resolves to the parsed `T`
    /// together with accumulated usage.
    type Output = Result<TypedPromptResponse<T>, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
918
#[cfg(test)]
mod tests {
    use super::TypedPromptResponse;
    use crate::{
        agent::AgentBuilder,
        completion::{
            AssistantContent, CompletionError, CompletionRequest, Message, Prompt, PromptError,
            Usage,
        },
        message::UserContent,
        test_utils::{
            AppendFailingMemory, CountingMemory, FailingMemory, MockCompletionModel, MockTurn,
        },
    };
    use serde::{Deserialize, Serialize};
    use serde_json::json;

    // Payload deriving only `Serialize`, to prove `TypedPromptResponse`
    // serialization does not require `Deserialize` on `T`.
    #[derive(Serialize)]
    struct SerializeOnly {
        value: &'static str,
    }

    // Mirror case: payload deriving only `Deserialize`.
    #[derive(Deserialize)]
    struct DeserializeOnly {
        value: String,
    }

    // `TypedPromptResponse<T: Serialize>` round-trips to JSON.
    #[test]
    fn typed_prompt_response_serializes_with_serialize_only_output() {
        let response = TypedPromptResponse::new(
            SerializeOnly { value: "ok" },
            Usage {
                input_tokens: 1,
                output_tokens: 2,
                total_tokens: 3,
                cached_input_tokens: 0,
                cache_creation_input_tokens: 0,
                reasoning_tokens: 0,
            },
        );

        let json = serde_json::to_string(&response).expect("serialize typed prompt response");
        assert!(json.contains("\"value\":\"ok\""));
    }

    // `TypedPromptResponse<T: Deserialize>` parses from JSON.
    #[test]
    fn typed_prompt_response_deserializes_with_deserialize_only_output() {
        let response: TypedPromptResponse<DeserializeOnly> = serde_json::from_str(
            r#"{"output":{"value":"ok"},"usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3,"cached_input_tokens":0,"cache_creation_input_tokens":0,"reasoning_tokens":0}}"#,
        )
        .expect("deserialize typed prompt response");

        assert_eq!(response.output.value, "ok");
        assert_eq!(response.usage.input_tokens, 1);
        assert_eq!(response.usage.output_tokens, 2);
        assert_eq!(response.usage.total_tokens, 3);
    }

    // Shared assertion: a follow-up request after a tool call must replay
    // the prompt, the assistant tool call, and the user tool result in
    // that order.
    fn validate_follow_up_tool_history(request: &CompletionRequest) {
        let history = request.chat_history.iter().cloned().collect::<Vec<_>>();
        assert_eq!(
            history.len(),
            3,
            "follow-up request should contain the prompt, assistant tool call, and user tool result: {history:?}"
        );

        assert!(matches!(
            history.first(),
            Some(Message::User { content })
            if matches!(
                content.first(),
                UserContent::Text(text) if text.text == "do tool work"
            )
        ));

        assert!(matches!(
            history.get(1),
            Some(Message::Assistant { content, .. })
            if matches!(
                content.first(),
                AssistantContent::ToolCall(tool_call)
                if tool_call.id == "tool_call_1"
                    && tool_call.call_id.as_deref() == Some("call_1")
            )
        ));

        assert!(matches!(
            history.get(2),
            Some(Message::User { content })
            if matches!(
                content.first(),
                UserContent::ToolResult(tool_result)
                if tool_result.id == "tool_call_1"
                    && tool_result.call_id.as_deref() == Some("call_1")
            )
        ));
    }

    // A terminal turn that is a single empty text chunk must end the loop
    // without error, without recording the empty assistant message, while
    // still accumulating usage from both turns.
    #[tokio::test]
    async fn prompt_request_stops_cleanly_on_empty_terminal_turn() {
        let model = MockCompletionModel::new([
            MockTurn::tool_call("tool_call_1", "missing_tool", json!({"input": "value"}))
                .with_call_id("call_1")
                .with_usage(Usage {
                    input_tokens: 1,
                    output_tokens: 1,
                    total_tokens: 2,
                    cached_input_tokens: 0,
                    cache_creation_input_tokens: 0,
                    reasoning_tokens: 0,
                }),
            MockTurn::text("").with_usage(Usage {
                input_tokens: 1,
                output_tokens: 1,
                total_tokens: 2,
                cached_input_tokens: 0,
                cache_creation_input_tokens: 0,
                reasoning_tokens: 0,
            }),
        ]);
        let agent = AgentBuilder::new(model).build();

        let response = agent
            .prompt("do tool work")
            .max_turns(3)
            .extended_details()
            .await
            .expect("empty terminal turn should not error");

        assert!(response.output.is_empty());
        assert_eq!(
            response.usage,
            Usage {
                input_tokens: 2,
                output_tokens: 2,
                total_tokens: 4,
                cached_input_tokens: 0,
                cache_creation_input_tokens: 0,
                reasoning_tokens: 0,
            }
        );

        let history = response
            .messages
            .expect("extended response should include history");
        assert_eq!(history.len(), 3);
        assert!(matches!(
            history.first(),
            Some(Message::User { content })
            if matches!(
                content.first(),
                UserContent::Text(text) if text.text == "do tool work"
            )
        ));
        assert!(history.iter().any(|message| matches!(
            message,
            Message::Assistant { content, .. }
            if matches!(
                content.first(),
                AssistantContent::ToolCall(tool_call)
                if tool_call.id == "tool_call_1"
                    && tool_call.call_id.as_deref() == Some("call_1")
            )
        )));
        assert!(history.iter().any(|message| matches!(
            message,
            Message::User { content }
            if matches!(
                content.first(),
                UserContent::ToolResult(tool_result)
                if tool_result.id == "tool_call_1"
                    && tool_result.call_id.as_deref() == Some("call_1")
            )
        )));
        // The empty assistant text turn must NOT appear in history.
        assert!(!history.iter().any(|message| matches!(
            message,
            Message::Assistant { content, .. }
            if content.iter().any(|item| matches!(
                item,
                AssistantContent::Text(text) if text.text.is_empty()
            ))
        )));
        let requests = agent.model.requests();
        assert_eq!(requests.len(), 2);
        validate_follow_up_tool_history(&requests[1]);
    }

    use crate::memory::{ConversationMemory, InMemoryConversationMemory};

    // Stored conversation memory is prepended to the request history.
    #[tokio::test]
    async fn memory_loads_into_request_history() {
        let memory = InMemoryConversationMemory::new();
        memory
            .append(
                "thread-1",
                vec![Message::user("hello"), Message::assistant("hi there")],
            )
            .await
            .unwrap();

        let model = MockCompletionModel::text("ack");
        let recorded = model.clone();

        let agent = AgentBuilder::new(model).memory(memory).build();
        let _ = agent
            .prompt("ping")
            .conversation("thread-1")
            .await
            .expect("prompt should succeed");

        let received = recorded.requests()[0]
            .chat_history
            .iter()
            .cloned()
            .collect::<Vec<_>>();
        assert_eq!(
            received.len(),
            3,
            "loaded memory (2) + current prompt should appear in request: {received:?}"
        );
    }

    // A successful run persists both the prompt and the response.
    #[tokio::test]
    async fn memory_appends_full_turn_after_success() {
        let memory = InMemoryConversationMemory::new();
        let model = MockCompletionModel::text("ack");
        let agent = AgentBuilder::new(model).memory(memory.clone()).build();

        let _ = agent
            .prompt("hello")
            .conversation("t1")
            .await
            .expect("prompt should succeed");

        let stored = memory.load("t1").await.unwrap();
        assert_eq!(stored.len(), 2, "user prompt + assistant response saved");
    }

    // `with_history` bypasses memory completely: no load, no append.
    #[tokio::test]
    async fn explicit_with_history_overrides_memory() {
        let memory = CountingMemory::default();
        memory
            .inner()
            .append("t1", vec![Message::user("from-memory")])
            .await
            .unwrap();

        let model = MockCompletionModel::text("ack");
        let recorded = model.clone();

        let agent = AgentBuilder::new(model).memory(memory.clone()).build();
        let _ = agent
            .prompt("hello")
            .conversation("t1")
            .with_history(vec![Message::user("from-caller")])
            .await
            .expect("prompt should succeed");

        assert_eq!(memory.load_count(), 0, "load skipped");
        let appends = memory.append_count();
        assert_eq!(appends, 0, "append skipped");

        let received = recorded.requests()[0]
            .chat_history
            .iter()
            .cloned()
            .collect::<Vec<_>>();
        assert_eq!(received.len(), 2, "caller history (1) + current prompt");
        assert!(matches!(
            received.first(),
            Some(Message::User { content })
            if matches!(content.first(), UserContent::Text(t) if t.text == "from-caller")
        ));
    }

    // A provider failure must not append anything to memory.
    #[tokio::test]
    async fn memory_unchanged_on_provider_error() {
        let memory = InMemoryConversationMemory::new();
        let model = MockCompletionModel::new([MockTurn::error("boom")]);

        let agent = AgentBuilder::new(model).memory(memory.clone()).build();
        let result = agent.prompt("hello").conversation("t1").await;
        assert!(result.is_err());

        let stored = memory.load("t1").await.unwrap();
        assert!(stored.is_empty(), "no append on error");
    }

    // Memory without a conversation id is never touched.
    #[tokio::test]
    async fn missing_conversation_id_behaves_as_no_memory() {
        let memory = CountingMemory::default();
        let model = MockCompletionModel::text("ack");
        let agent = AgentBuilder::new(model).memory(memory.clone()).build();

        let _ = agent.prompt("hello").await.expect("prompt should succeed");

        assert_eq!(memory.load_count(), 0);
        assert_eq!(memory.append_count(), 0);
    }

    // The agent-level default conversation id applies when the request
    // does not set one.
    #[tokio::test]
    async fn default_conversation_id_is_used_when_none_per_request() {
        let memory = InMemoryConversationMemory::new();
        let model = MockCompletionModel::text("ack");
        let agent = AgentBuilder::new(model)
            .memory(memory.clone())
            .conversation_id("default-thread")
            .build();

        let _ = agent.prompt("hello").await.expect("prompt should succeed");
        let stored = memory.load("default-thread").await.unwrap();
        assert_eq!(stored.len(), 2);
    }

    // A memory filter (here: keep last 2) shapes what is loaded into the
    // request history.
    #[tokio::test]
    async fn with_filter_truncates_loaded_history() {
        let memory = InMemoryConversationMemory::new()
            .with_filter(|msgs: Vec<Message>| msgs.into_iter().rev().take(2).rev().collect());
        memory
            .append(
                "t1",
                vec![
                    Message::user("1"),
                    Message::assistant("2"),
                    Message::user("3"),
                    Message::assistant("4"),
                ],
            )
            .await
            .unwrap();

        let model = MockCompletionModel::text("ack");
        let recorded = model.clone();
        let agent = AgentBuilder::new(model).memory(memory).build();

        let _ = agent
            .prompt("ping")
            .conversation("t1")
            .await
            .expect("prompt should succeed");

        let received = recorded.requests()[0]
            .chat_history
            .iter()
            .cloned()
            .collect::<Vec<_>>();
        assert_eq!(
            received.len(),
            3,
            "window-truncated history (2) + current prompt"
        );
    }

    // `without_memory` disables load and append for a single request even
    // when the agent has a default conversation id.
    #[tokio::test]
    async fn without_memory_disables_for_request() {
        let memory = CountingMemory::default();
        let model = MockCompletionModel::text("ack");
        let agent = AgentBuilder::new(model)
            .memory(memory.clone())
            .conversation_id("t1")
            .build();

        let _ = agent
            .prompt("hello")
            .without_memory()
            .await
            .expect("prompt should succeed");

        assert_eq!(memory.load_count(), 0);
        assert_eq!(memory.append_count(), 0);
    }

    // Load failures abort the prompt and surface as a request error.
    #[tokio::test]
    async fn memory_load_error_surfaces_as_prompt_error() {
        let model = MockCompletionModel::text("ack");
        let agent = AgentBuilder::new(model)
            .memory(FailingMemory::default())
            .build();
        let result = agent.prompt("hello").conversation("t1").await;

        match result {
            Err(PromptError::CompletionError(CompletionError::RequestError(err))) => {
                let msg = format!("{err}");
                assert!(msg.contains("load boom"), "got: {msg}");
            }
            other => panic!("expected PromptError::CompletionError(RequestError), got {other:?}"),
        }
    }

    // Append failures are best-effort: the successful response still
    // reaches the caller.
    #[tokio::test]
    async fn memory_append_error_does_not_drop_response() {
        let model = MockCompletionModel::text("ack");
        let agent = AgentBuilder::new(model)
            .memory(AppendFailingMemory::default())
            .build();
        let response: String = agent
            .prompt("hello")
            .conversation("t1")
            .await
            .expect("append failure must not block successful completion");

        assert!(!response.is_empty());
    }
}