1pub mod hooks;
2pub mod streaming;
3
4use super::{
5 Agent,
6 completion::{DynamicContextStore, build_completion_request},
7};
8use crate::{
9 OneOrMany,
10 completion::{CompletionModel, Document, Message, PromptError, Usage},
11 json_utils,
12 message::{AssistantContent, ToolChoice, ToolResultContent, UserContent},
13 tool::server::ToolServerHandle,
14 wasm_compat::{WasmBoxedFuture, WasmCompatSend},
15};
16use futures::{StreamExt, stream};
17use hooks::{HookAction, PromptHook, ToolCallHookAction};
18use serde::{Deserialize, Serialize};
19use std::{
20 future::IntoFuture,
21 marker::PhantomData,
22 sync::{
23 Arc,
24 atomic::{AtomicU64, Ordering},
25 },
26};
27use tracing::info_span;
28use tracing::{Instrument, span::Id};
29
/// Typestate marker trait for [`PromptRequest`]; implemented only by
/// [`Standard`] and [`Extended`].
pub trait PromptType {}
/// Typestate: awaiting the request yields only the final `String` output.
pub struct Standard;
/// Typestate: awaiting the request yields a full [`PromptResponse`]
/// (output text, aggregated usage, and message history).
pub struct Extended;

impl PromptType for Standard {}
impl PromptType for Extended {}
36
/// Builder for a (possibly multi-turn) prompt sent to a completion model.
///
/// `S` is the typestate ([`Standard`] or [`Extended`]) selecting what awaiting
/// the request returns; `M` is the completion model; `P` is the hook type
/// invoked around completions and tool calls.
pub struct PromptRequest<S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    // The user prompt that starts the conversation loop.
    prompt: Message,
    // Optional pre-existing chat history prepended to every model request.
    chat_history: Option<Vec<Message>>,
    // Maximum number of conversation turns before erroring (see `send`).
    max_turns: usize,

    // The completion model the request is sent to.
    model: Arc<M>,
    // Agent display name used in tracing spans (falls back to a default).
    agent_name: Option<String>,
    // System instructions sent with each completion request.
    preamble: Option<String>,
    // Documents always attached to the request context.
    static_context: Vec<Document>,
    // Optional sampling temperature forwarded to the model.
    temperature: Option<f64>,
    // Optional completion token cap forwarded to the model.
    max_tokens: Option<u64>,
    // Provider-specific extra parameters, forwarded verbatim.
    additional_params: Option<serde_json::Value>,
    // Handle used to list and invoke tools during the loop.
    tool_server_handle: ToolServerHandle,
    // Store of dynamically-retrieved context documents.
    dynamic_context: DynamicContextStore,
    // Optional constraint on how the model may choose tools.
    tool_choice: Option<ToolChoice>,

    // Zero-sized typestate tag.
    state: PhantomData<S>,
    // Optional hook consulted before/after completions and tool calls.
    hook: Option<P>,
    // How many tool calls may execute concurrently per turn.
    concurrency: usize,
    // Optional JSON schema for structured (typed) output.
    output_schema: Option<schemars::Schema>,
}
89
impl<M, P> PromptRequest<Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Creates a [`Standard`] prompt request seeded from an [`Agent`]'s
    /// configuration (model, preamble, context, tool handle, hooks, schema)
    /// and the given prompt.
    ///
    /// Starts with no extra chat history and a tool concurrency of 1.
    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
        PromptRequest {
            prompt: prompt.into(),
            chat_history: None,
            // 0 turns when the agent has no configured default.
            max_turns: agent.default_max_turns.unwrap_or_default(),
            model: agent.model.clone(),
            agent_name: agent.name.clone(),
            preamble: agent.preamble.clone(),
            static_context: agent.static_context.clone(),
            temperature: agent.temperature,
            max_tokens: agent.max_tokens,
            additional_params: agent.additional_params.clone(),
            tool_server_handle: agent.tool_server_handle.clone(),
            dynamic_context: agent.dynamic_context.clone(),
            tool_choice: agent.tool_choice.clone(),
            state: PhantomData,
            hook: agent.hook.clone(),
            concurrency: 1,
            output_schema: agent.output_schema.clone(),
        }
    }
}
118
impl<S, M, P> PromptRequest<S, M, P>
where
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Switches the request into the [`Extended`] typestate so that awaiting
    /// it yields a [`PromptResponse`] (with usage and history) instead of a
    /// plain `String`. Every field is moved across unchanged.
    pub fn extended_details(self) -> PromptRequest<Extended, M, P> {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_turns: self.max_turns,
            model: self.model,
            agent_name: self.agent_name,
            preamble: self.preamble,
            static_context: self.static_context,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            tool_server_handle: self.tool_server_handle,
            dynamic_context: self.dynamic_context,
            tool_choice: self.tool_choice,
            state: PhantomData,
            hook: self.hook,
            concurrency: self.concurrency,
            output_schema: self.output_schema,
        }
    }

    /// Sets the maximum number of conversation turns before the prompt loop
    /// errors with `MaxTurnsError`.
    pub fn max_turns(mut self, depth: usize) -> Self {
        self.max_turns = depth;
        self
    }

    /// Sets how many tool calls may run concurrently within a single turn.
    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = concurrency;
        self
    }

    /// Supplies pre-existing chat history that is prepended to every model
    /// request made by the loop. Replaces any previously set history.
    pub fn with_history<I, T>(mut self, history: I) -> Self
    where
        I: IntoIterator<Item = T>,
        T: Into<Message>,
    {
        self.chat_history = Some(history.into_iter().map(Into::into).collect());
        self
    }

    /// Replaces the hook, potentially changing the hook type parameter.
    /// All other fields are moved across unchanged.
    pub fn with_hook<P2>(self, hook: P2) -> PromptRequest<S, M, P2>
    where
        P2: PromptHook<M>,
    {
        PromptRequest {
            prompt: self.prompt,
            chat_history: self.chat_history,
            max_turns: self.max_turns,
            model: self.model,
            agent_name: self.agent_name,
            preamble: self.preamble,
            static_context: self.static_context,
            temperature: self.temperature,
            max_tokens: self.max_tokens,
            additional_params: self.additional_params,
            tool_server_handle: self.tool_server_handle,
            dynamic_context: self.dynamic_context,
            tool_choice: self.tool_choice,
            state: PhantomData,
            hook: Some(hook),
            concurrency: self.concurrency,
            output_schema: self.output_schema,
        }
    }
}
204
/// Allows `.await`ing a [`Standard`] request directly; resolves to the final
/// output text only.
impl<M, P> IntoFuture for PromptRequest<Standard, M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    type Output = Result<String, PromptError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
220
/// Allows `.await`ing an [`Extended`] request directly; resolves to a full
/// [`PromptResponse`] with usage and message history.
impl<M, P> IntoFuture for PromptRequest<Extended, M, P>
where
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    type Output = Result<PromptResponse, PromptError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
233
impl<M, P> PromptRequest<Standard, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Runs the full prompt loop via the [`Extended`] implementation and keeps
    /// only the final output text, discarding usage and history.
    async fn send(self) -> Result<String, PromptError> {
        self.extended_details().send().await.map(|resp| resp.output)
    }
}
243
/// Result of a completed prompt run: the final text plus run-level metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub struct PromptResponse {
    /// The model's final text output (joined text parts of the last turn).
    pub output: String,
    /// Token usage accumulated across every turn of the run.
    pub usage: Usage,
    /// Messages generated during the run (prompt, assistant turns, tool
    /// results), when captured by the loop.
    pub messages: Option<Vec<Message>>,
}
251
252impl std::fmt::Display for PromptResponse {
253 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254 self.output.fmt(f)
255 }
256}
257
impl PromptResponse {
    /// Creates a response with the given output text and usage; the message
    /// history starts out unset.
    pub fn new(output: impl Into<String>, usage: Usage) -> Self {
        Self {
            output: output.into(),
            usage,
            messages: None,
        }
    }

    /// Attaches the message history produced during the prompt run.
    pub fn with_messages(mut self, messages: Vec<Message>) -> Self {
        self.messages = Some(messages);
        self
    }
}
272
/// Result of a structured-output prompt run: the parsed value plus usage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TypedPromptResponse<T> {
    /// The model's output deserialized into `T`.
    pub output: T,
    /// Token usage accumulated across the run.
    pub usage: Usage,
}
278
impl<T> TypedPromptResponse<T> {
    /// Pairs an already-parsed output value with its token usage.
    pub fn new(output: T, usage: Usage) -> Self {
        Self { output, usage }
    }
}
284
285const UNKNOWN_AGENT_NAME: &str = "Unnamed Agent";
286
287fn build_history_for_request(
289 chat_history: Option<&[Message]>,
290 new_messages: &[Message],
291) -> Vec<Message> {
292 let input = chat_history.unwrap_or(&[]);
293 input.iter().chain(new_messages.iter()).cloned().collect()
294}
295
296fn build_full_history(
298 chat_history: Option<&[Message]>,
299 new_messages: Vec<Message>,
300) -> Vec<Message> {
301 let input = chat_history.unwrap_or(&[]);
302 input.iter().cloned().chain(new_messages).collect()
303}
304
305fn is_empty_assistant_turn(choice: &OneOrMany<AssistantContent>) -> bool {
306 choice.len() == 1
307 && matches!(
308 choice.first(),
309 AssistantContent::Text(text) if text.text.is_empty()
310 )
311}
312
impl<M, P> PromptRequest<Extended, M, P>
where
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Agent name for tracing, falling back to [`UNKNOWN_AGENT_NAME`].
    fn agent_name(&self) -> &str {
        self.agent_name.as_deref().unwrap_or(UNKNOWN_AGENT_NAME)
    }

    /// Runs the multi-turn prompt loop.
    ///
    /// Each turn sends the prompt plus accumulated history to the model. If
    /// the model replies with tool calls, they are executed (up to
    /// `self.concurrency` at a time) and their results are pushed back into
    /// the history for the next turn; otherwise the text parts are joined and
    /// returned as the final [`PromptResponse`]. Hooks may cancel the run at
    /// several points, producing `PromptError::prompt_cancelled`. Exceeding
    /// the turn budget yields `PromptError::MaxTurnsError`.
    async fn send(self) -> Result<PromptResponse, PromptError> {
        // Reuse an already-active span if the caller set one up; otherwise
        // open a fresh `invoke_agent` span with empty fields recorded later.
        let agent_span = if tracing::Span::current().is_disabled() {
            info_span!(
                "invoke_agent",
                gen_ai.operation.name = "invoke_agent",
                gen_ai.agent.name = self.agent_name(),
                gen_ai.system_instructions = self.preamble,
                gen_ai.prompt = tracing::field::Empty,
                gen_ai.completion = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
            )
        } else {
            tracing::Span::current()
        };

        if let Some(text) = self.prompt.rag_text() {
            agent_span.record("gen_ai.prompt", text);
        }

        let agent_name_for_span = self.agent_name.clone();
        let chat_history = self.chat_history;
        // Messages produced during this run; the last element is always the
        // "pending prompt" for the next model call.
        let mut new_messages: Vec<Message> = vec![self.prompt.clone()];

        let mut current_max_turns = 0;
        let mut usage = Usage::new();
        // Id of the most recent chat/tool span, used to chain subsequent
        // spans via `follows_from`. Atomic so the tool closures can share it.
        let current_span_id: AtomicU64 = AtomicU64::new(0);

        let last_prompt = loop {
            // Split the pending prompt (last element) from the history built
            // up so far. `new_messages` always holds at least the original
            // prompt, so an empty vec indicates a broken invariant.
            let Some((prompt_ref, history_for_current_turn)) = new_messages.split_last() else {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    "prompt loop lost its pending prompt",
                ));
            };
            let prompt = prompt_ref.clone();

            // NOTE: the `+ 1` grants one extra turn beyond `max_turns` before
            // the loop breaks out to the MaxTurnsError below.
            if current_max_turns > self.max_turns + 1 {
                break prompt;
            }

            current_max_turns += 1;

            if self.max_turns > 1 {
                tracing::info!(
                    "Current conversation depth: {}/{}",
                    current_max_turns,
                    self.max_turns
                );
            }

            let history_for_hook =
                build_history_for_request(chat_history.as_deref(), history_for_current_turn);

            // Pre-completion hook: may cancel the entire run.
            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_call(&prompt, &history_for_hook).await
            {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    reason,
                ));
            }

            let span = tracing::Span::current();
            let chat_span = info_span!(
                target: "rig::agent_chat",
                parent: &span,
                "chat",
                gen_ai.operation.name = "chat",
                gen_ai.agent.name = agent_name_for_span.as_deref().unwrap_or(UNKNOWN_AGENT_NAME),
                gen_ai.system_instructions = self.preamble,
                gen_ai.provider.name = tracing::field::Empty,
                gen_ai.request.model = tracing::field::Empty,
                gen_ai.response.id = tracing::field::Empty,
                gen_ai.response.model = tracing::field::Empty,
                gen_ai.usage.output_tokens = tracing::field::Empty,
                gen_ai.usage.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_read.input_tokens = tracing::field::Empty,
                gen_ai.usage.cache_creation.input_tokens = tracing::field::Empty,
                gen_ai.input.messages = tracing::field::Empty,
                gen_ai.output.messages = tracing::field::Empty,
            );

            // Link this chat span to the previous one (if any) so turns form
            // a causal chain instead of being unrelated siblings.
            let chat_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                chat_span.follows_from(id).to_owned()
            } else {
                chat_span
            };

            if let Some(id) = chat_span.id() {
                current_span_id.store(id.into_u64(), Ordering::SeqCst);
            };

            let history_for_request =
                build_history_for_request(chat_history.as_deref(), history_for_current_turn);

            // Build and send this turn's completion request.
            let resp = build_completion_request(
                &self.model,
                prompt.clone(),
                &history_for_request,
                self.preamble.as_deref(),
                &self.static_context,
                self.temperature,
                self.max_tokens,
                self.additional_params.as_ref(),
                self.tool_choice.as_ref(),
                &self.tool_server_handle,
                &self.dynamic_context,
                self.output_schema.as_ref(),
            )
            .await?
            .send()
            .instrument(chat_span.clone())
            .await?;

            usage += resp.usage;

            // Post-completion hook: may also cancel the run.
            if let Some(ref hook) = self.hook
                && let HookAction::Terminate { reason } =
                    hook.on_completion_response(&prompt, &resp).await
            {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    reason,
                ));
            }

            // Separate tool calls from text parts of the assistant turn.
            let (tool_calls, texts): (Vec<_>, Vec<_>) = resp
                .choice
                .iter()
                .partition(|choice| matches!(choice, AssistantContent::ToolCall(_)));

            // Skip recording an assistant message that is a single empty text
            // chunk; it adds nothing to the history.
            if !is_empty_assistant_turn(&resp.choice) {
                new_messages.push(Message::Assistant {
                    id: resp.message_id.clone(),
                    content: resp.choice.clone(),
                });
            }

            // No tool calls: this is the final turn — join text parts,
            // record span fields, and return.
            if tool_calls.is_empty() {
                let merged_texts = texts
                    .into_iter()
                    .filter_map(|content| {
                        if let AssistantContent::Text(text) = content {
                            Some(text.text.clone())
                        } else {
                            None
                        }
                    })
                    .collect::<Vec<_>>()
                    .join("\n");

                if self.max_turns > 1 {
                    tracing::info!("Depth reached: {}/{}", current_max_turns, self.max_turns);
                }

                agent_span.record("gen_ai.completion", &merged_texts);
                agent_span.record("gen_ai.usage.input_tokens", usage.input_tokens);
                agent_span.record("gen_ai.usage.output_tokens", usage.output_tokens);
                agent_span.record(
                    "gen_ai.usage.cache_read.input_tokens",
                    usage.cached_input_tokens,
                );
                agent_span.record(
                    "gen_ai.usage.cache_creation.input_tokens",
                    usage.cache_creation_input_tokens,
                );

                return Ok(PromptResponse::new(merged_texts, usage).with_messages(new_messages));
            }

            let hook = self.hook.clone();
            let tool_server_handle = self.tool_server_handle.clone();

            // Snapshot the history once; each tool future gets a clone to
            // report in cancellation errors.
            let full_history_for_errors =
                build_full_history(chat_history.as_deref(), new_messages.clone());

            // Execute all tool calls, at most `self.concurrency` in flight.
            let tool_calls: Vec<AssistantContent> = tool_calls.into_iter().cloned().collect();
            let tool_content = stream::iter(tool_calls)
                .map(|choice| {
                    let hook1 = hook.clone();
                    let hook2 = hook.clone();
                    let tool_server_handle = tool_server_handle.clone();

                    let tool_span = info_span!(
                        "execute_tool",
                        gen_ai.operation.name = "execute_tool",
                        gen_ai.tool.type = "function",
                        gen_ai.tool.name = tracing::field::Empty,
                        gen_ai.tool.call.id = tracing::field::Empty,
                        gen_ai.tool.call.arguments = tracing::field::Empty,
                        gen_ai.tool.call.result = tracing::field::Empty
                    );

                    // Chain the tool span after the most recent span, same as
                    // the chat span above.
                    let tool_span = if current_span_id.load(Ordering::SeqCst) != 0 {
                        let id = Id::from_u64(current_span_id.load(Ordering::SeqCst));
                        tool_span.follows_from(id).to_owned()
                    } else {
                        tool_span
                    };

                    if let Some(id) = tool_span.id() {
                        current_span_id.store(id.into_u64(), Ordering::SeqCst);
                    };

                    let cloned_history_for_error = full_history_for_errors.clone();

                    async move {
                        if let AssistantContent::ToolCall(tool_call) = choice {
                            let tool_name = &tool_call.function.name;
                            let args =
                                json_utils::value_to_json_string(&tool_call.function.arguments);
                            // Internal correlation id for the hook pair below.
                            let internal_call_id = nanoid::nanoid!();
                            let tool_span = tracing::Span::current();
                            tool_span.record("gen_ai.tool.name", tool_name);
                            tool_span.record("gen_ai.tool.call.id", &tool_call.id);
                            tool_span.record("gen_ai.tool.call.arguments", &args);
                            // Pre-call hook: may terminate the run or skip
                            // just this tool call.
                            if let Some(hook) = hook1 {
                                let action = hook
                                    .on_tool_call(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                    )
                                    .await;

                                if let ToolCallHookAction::Terminate { reason } = action {
                                    return Err(PromptError::prompt_cancelled(
                                        cloned_history_for_error,
                                        reason,
                                    ));
                                }

                                // Skipped calls still yield a tool result (the
                                // skip reason) so the model sees a response.
                                if let ToolCallHookAction::Skip { reason } = action {
                                    tracing::info!(
                                        tool_name = tool_name,
                                        reason = reason,
                                        "Tool call rejected"
                                    );
                                    if let Some(call_id) = tool_call.call_id.clone() {
                                        return Ok(UserContent::tool_result_with_call_id(
                                            tool_call.id.clone(),
                                            call_id,
                                            OneOrMany::one(reason.into()),
                                        ));
                                    } else {
                                        return Ok(UserContent::tool_result(
                                            tool_call.id.clone(),
                                            OneOrMany::one(reason.into()),
                                        ));
                                    }
                                }
                            }
                            // Tool errors are not fatal: the error text is
                            // fed back to the model as the tool's output.
                            let output = match tool_server_handle.call_tool(tool_name, &args).await
                            {
                                Ok(res) => res,
                                Err(e) => {
                                    tracing::warn!("Error while executing tool: {e}");
                                    e.to_string()
                                }
                            };
                            // Post-result hook: may terminate the run.
                            if let Some(hook) = hook2
                                && let HookAction::Terminate { reason } = hook
                                    .on_tool_result(
                                        tool_name,
                                        tool_call.call_id.clone(),
                                        &internal_call_id,
                                        &args,
                                        &output.to_string(),
                                    )
                                    .await
                            {
                                return Err(PromptError::prompt_cancelled(
                                    cloned_history_for_error,
                                    reason,
                                ));
                            }

                            tool_span.record("gen_ai.tool.call.result", &output);
                            tracing::info!(
                                "executed tool {tool_name} with args {args}. result: {output}"
                            );
                            if let Some(call_id) = tool_call.call_id.clone() {
                                Ok(UserContent::tool_result_with_call_id(
                                    tool_call.id.clone(),
                                    call_id,
                                    ToolResultContent::from_tool_output(output),
                                ))
                            } else {
                                Ok(UserContent::tool_result(
                                    tool_call.id.clone(),
                                    ToolResultContent::from_tool_output(output),
                                ))
                            }
                        } else {
                            // Unreachable by construction: the partition above
                            // only feeds ToolCall variants into this stream.
                            Err(PromptError::prompt_cancelled(
                                Vec::new(),
                                "tool execution received non-tool assistant content",
                            ))
                        }
                    }
                    .instrument(tool_span)
                })
                .buffer_unordered(self.concurrency)
                .collect::<Vec<Result<UserContent, PromptError>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>, _>>()?;

            let Some(content) = OneOrMany::from_iter_optional(tool_content) else {
                return Err(PromptError::prompt_cancelled(
                    build_full_history(chat_history.as_deref(), new_messages),
                    "tool execution produced no tool results",
                ));
            };

            // Tool results become the pending user message for the next turn.
            new_messages.push(Message::User { content });
        };

        Err(PromptError::MaxTurnsError {
            max_turns: self.max_turns,
            chat_history: build_full_history(chat_history.as_deref(), new_messages).into(),
            prompt: last_prompt.into(),
        })
    }
}
664
665use crate::completion::StructuredOutputError;
670use schemars::{JsonSchema, schema_for};
671use serde::de::DeserializeOwned;
672
/// A [`PromptRequest`] whose output is deserialized into a structured type `T`
/// (via a JSON schema attached to the inner request).
pub struct TypedPromptRequest<T, S, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    // The underlying prompt request that performs the actual run.
    inner: PromptRequest<S, M, P>,
    // Carries `T` without storing a value of it.
    _phantom: std::marker::PhantomData<T>,
}
699
impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Creates a typed prompt request from an agent and prompt, overriding the
    /// inner request's output schema with the JSON schema generated for `T`.
    pub fn from_agent(agent: &Agent<M, P>, prompt: impl Into<Message>) -> Self {
        let mut inner = PromptRequest::from_agent(agent, prompt);
        inner.output_schema = Some(schema_for!(T));
        Self {
            inner,
            _phantom: std::marker::PhantomData,
        }
    }
}
719
impl<T, S, M, P> TypedPromptRequest<T, S, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend,
    S: PromptType,
    M: CompletionModel,
    P: PromptHook<M>,
{
    /// Switches to the [`Extended`] typestate; awaiting then yields a
    /// [`TypedPromptResponse`] (parsed output plus usage).
    pub fn extended_details(self) -> TypedPromptRequest<T, Extended, M, P> {
        TypedPromptRequest {
            inner: self.inner.extended_details(),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Delegates to [`PromptRequest::max_turns`] on the inner request.
    pub fn max_turns(mut self, depth: usize) -> Self {
        self.inner = self.inner.max_turns(depth);
        self
    }

    /// Delegates to [`PromptRequest::with_tool_concurrency`].
    pub fn with_tool_concurrency(mut self, concurrency: usize) -> Self {
        self.inner = self.inner.with_tool_concurrency(concurrency);
        self
    }

    /// Delegates to [`PromptRequest::with_history`].
    pub fn with_history<I, H>(mut self, history: I) -> Self
    where
        I: IntoIterator<Item = H>,
        H: Into<Message>,
    {
        self.inner = self.inner.with_history(history);
        self
    }

    /// Delegates to [`PromptRequest::with_hook`], changing the hook type.
    pub fn with_hook<P2>(self, hook: P2) -> TypedPromptRequest<T, S, M, P2>
    where
        P2: PromptHook<M>,
    {
        TypedPromptRequest {
            inner: self.inner.with_hook(hook),
            _phantom: std::marker::PhantomData,
        }
    }
}
780
781impl<T, M, P> TypedPromptRequest<T, Standard, M, P>
782where
783 T: JsonSchema + DeserializeOwned + WasmCompatSend,
784 M: CompletionModel,
785 P: PromptHook<M>,
786{
787 async fn send(self) -> Result<T, StructuredOutputError> {
789 let response = self.inner.send().await.map_err(Box::new)?;
790
791 if response.is_empty() {
792 return Err(StructuredOutputError::EmptyResponse);
793 }
794
795 let parsed: T = serde_json::from_str(&response)?;
796 Ok(parsed)
797 }
798}
799
800impl<T, M, P> TypedPromptRequest<T, Extended, M, P>
801where
802 T: JsonSchema + DeserializeOwned + WasmCompatSend,
803 M: CompletionModel,
804 P: PromptHook<M>,
805{
806 async fn send(self) -> Result<TypedPromptResponse<T>, StructuredOutputError> {
808 let response = self.inner.send().await.map_err(Box::new)?;
809
810 if response.output.is_empty() {
811 return Err(StructuredOutputError::EmptyResponse);
812 }
813
814 let parsed: T = serde_json::from_str(&response.output)?;
815 Ok(TypedPromptResponse::new(parsed, response.usage))
816 }
817}
818
/// Allows `.await`ing a typed [`Standard`] request; resolves to the parsed
/// value `T` alone.
impl<T, M, P> IntoFuture for TypedPromptRequest<T, Standard, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    type Output = Result<T, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
832
/// Allows `.await`ing a typed [`Extended`] request; resolves to a
/// [`TypedPromptResponse`] with the parsed value and usage.
impl<T, M, P> IntoFuture for TypedPromptRequest<T, Extended, M, P>
where
    T: JsonSchema + DeserializeOwned + WasmCompatSend + 'static,
    M: CompletionModel + 'static,
    P: PromptHook<M> + 'static,
{
    type Output = Result<TypedPromptResponse<T>, StructuredOutputError>;
    type IntoFuture = WasmBoxedFuture<'static, Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(self.send())
    }
}
846
#[cfg(test)]
mod tests {
    use super::TypedPromptResponse;
    use crate::{
        OneOrMany,
        agent::AgentBuilder,
        completion::{
            AssistantContent, CompletionError, CompletionModel, CompletionRequest,
            CompletionResponse, Message, Prompt, Usage,
        },
        message::UserContent,
        streaming::StreamingCompletionResponse,
    };
    use serde::{Deserialize, Serialize};
    use serde_json::json;
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    // Serialize-only payload: proves `TypedPromptResponse` serialization does
    // not require `T: Deserialize`.
    #[derive(Serialize)]
    struct SerializeOnly {
        value: &'static str,
    }

    // Deserialize-only payload: proves `TypedPromptResponse` deserialization
    // does not require `T: Serialize`.
    #[derive(Deserialize)]
    struct DeserializeOnly {
        value: String,
    }

    #[test]
    fn typed_prompt_response_serializes_with_serialize_only_output() {
        let response = TypedPromptResponse::new(
            SerializeOnly { value: "ok" },
            Usage {
                input_tokens: 1,
                output_tokens: 2,
                total_tokens: 3,
                cached_input_tokens: 0,
                cache_creation_input_tokens: 0,
            },
        );

        let json = serde_json::to_string(&response).expect("serialize typed prompt response");
        assert!(json.contains("\"value\":\"ok\""));
    }

    #[test]
    fn typed_prompt_response_deserializes_with_deserialize_only_output() {
        let response: TypedPromptResponse<DeserializeOnly> = serde_json::from_str(
            r#"{"output":{"value":"ok"},"usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3,"cached_input_tokens":0,"cache_creation_input_tokens":0}}"#,
        )
        .expect("deserialize typed prompt response");

        assert_eq!(response.output.value, "ok");
        assert_eq!(response.usage.input_tokens, 1);
        assert_eq!(response.usage.output_tokens, 2);
        assert_eq!(response.usage.total_tokens, 3);
    }

    // Asserts that a follow-up model request contains exactly three messages
    // in order: the user prompt, the assistant tool call, and the user tool
    // result (with matching ids).
    fn validate_follow_up_tool_history(request: &CompletionRequest) {
        let history = request.chat_history.iter().cloned().collect::<Vec<_>>();
        assert_eq!(
            history.len(),
            3,
            "follow-up request should contain the prompt, assistant tool call, and user tool result: {history:?}"
        );

        assert!(matches!(
            history.first(),
            Some(Message::User { content })
            if matches!(
                content.first(),
                UserContent::Text(text) if text.text == "do tool work"
            )
        ));

        assert!(matches!(
            history.get(1),
            Some(Message::Assistant { content, .. })
            if matches!(
                content.first(),
                AssistantContent::ToolCall(tool_call)
                if tool_call.id == "tool_call_1"
                    && tool_call.call_id.as_deref() == Some("call_1")
            )
        ));

        assert!(matches!(
            history.get(2),
            Some(Message::User { content })
            if matches!(
                content.first(),
                UserContent::ToolResult(tool_result)
                if tool_result.id == "tool_call_1"
                    && tool_result.call_id.as_deref() == Some("call_1")
            )
        ));
    }

    // Mock model: first turn issues a tool call to a nonexistent tool, second
    // turn validates the replayed history and replies with empty text.
    #[derive(Clone, Default)]
    struct EmptyFinalTurnMockModel {
        // Counts completion calls so each turn can behave differently.
        turn_counter: Arc<AtomicUsize>,
    }

    #[allow(refining_impl_trait)]
    impl CompletionModel for EmptyFinalTurnMockModel {
        type Response = ();
        type StreamingResponse = ();
        type Client = ();

        fn make(_: &Self::Client, _: impl Into<String>) -> Self {
            Self::default()
        }

        async fn completion(
            &self,
            request: CompletionRequest,
        ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
            let turn = self.turn_counter.fetch_add(1, Ordering::SeqCst);

            let choice = if turn == 0 {
                OneOrMany::one(AssistantContent::tool_call_with_call_id(
                    "tool_call_1",
                    "call_1".to_string(),
                    "missing_tool",
                    json!({"input": "value"}),
                ))
            } else {
                validate_follow_up_tool_history(&request);
                // Empty text: exercises the empty-terminal-turn path in send().
                OneOrMany::one(AssistantContent::text(""))
            };

            Ok(CompletionResponse {
                choice,
                usage: Usage {
                    input_tokens: 1,
                    output_tokens: 1,
                    total_tokens: 2,
                    cached_input_tokens: 0,
                    cache_creation_input_tokens: 0,
                },
                raw_response: (),
                message_id: None,
            })
        }

        async fn stream(
            &self,
            _request: CompletionRequest,
        ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
            Err(CompletionError::ProviderError(
                "stream is unused in this non-streaming test".to_string(),
            ))
        }
    }

    // End-to-end: a run whose final assistant turn is empty text should return
    // Ok with empty output, aggregated usage, and a history that excludes the
    // empty assistant message.
    #[tokio::test]
    async fn prompt_request_stops_cleanly_on_empty_terminal_turn() {
        let model = EmptyFinalTurnMockModel::default();
        let turn_counter = model.turn_counter.clone();
        let agent = AgentBuilder::new(model).build();

        let response = agent
            .prompt("do tool work")
            .max_turns(3)
            .extended_details()
            .await
            .expect("empty terminal turn should not error");

        assert!(response.output.is_empty());
        // Two turns at 1 input + 1 output token each.
        assert_eq!(
            response.usage,
            Usage {
                input_tokens: 2,
                output_tokens: 2,
                total_tokens: 4,
                cached_input_tokens: 0,
                cache_creation_input_tokens: 0,
            }
        );

        let history = response
            .messages
            .expect("extended response should include history");
        assert_eq!(history.len(), 3);
        assert!(matches!(
            history.first(),
            Some(Message::User { content })
            if matches!(
                content.first(),
                UserContent::Text(text) if text.text == "do tool work"
            )
        ));
        assert!(history.iter().any(|message| matches!(
            message,
            Message::Assistant { content, .. }
            if matches!(
                content.first(),
                AssistantContent::ToolCall(tool_call)
                if tool_call.id == "tool_call_1"
                    && tool_call.call_id.as_deref() == Some("call_1")
            )
        )));
        assert!(history.iter().any(|message| matches!(
            message,
            Message::User { content }
            if matches!(
                content.first(),
                UserContent::ToolResult(tool_result)
                if tool_result.id == "tool_call_1"
                    && tool_result.call_id.as_deref() == Some("call_1")
            )
        )));
        // The empty assistant text must NOT have been recorded in history.
        assert!(!history.iter().any(|message| matches!(
            message,
            Message::Assistant { content, .. }
            if content.iter().any(|item| matches!(
                item,
                AssistantContent::Text(text) if text.text.is_empty()
            ))
        )));
        assert_eq!(turn_counter.load(Ordering::SeqCst), 2);
    }
}