1use std::collections::HashMap;
10use std::pin::Pin;
11use std::sync::Arc;
12
13use async_trait::async_trait;
14use futures::Stream;
15use serde::{Deserialize, Serialize};
16use serde_json::Value;
17
18use super::base::{BaseLanguageModel, LangSmithParams, LanguageModelConfig, LanguageModelInput};
19use super::model_profile::ModelProfile;
20use crate::GenerationType;
21use crate::callbacks::{
22 AsyncCallbackManagerForLLMRun, BaseCallbackHandler, CallbackManagerForLLMRun, Callbacks,
23};
24use crate::error::{Error, Result};
25use crate::messages::{AIMessage, AIMessageChunk, BaseMessage, ChunkPosition, UsageMetadata};
26use crate::outputs::{ChatGeneration, ChatGenerationChunk, ChatResult, Generation, LLMResult};
27use crate::rate_limiters::BaseRateLimiter;
28use crate::tools::{BaseTool, ToolDefinition};
29
/// Pinned, boxed, `Send` stream of raw [`ChatChunk`] results.
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<ChatChunk>> + Send>>;

/// Pinned, boxed, `Send` stream of [`ChatGenerationChunk`] results, as
/// produced by `BaseChatModel::_stream` / `_astream`.
pub type ChatGenerationStream = Pin<Box<dyn Stream<Item = Result<ChatGenerationChunk>> + Send>>;

/// Pinned, boxed, `Send` stream of [`AIMessageChunk`] results, as returned
/// by the public `stream` / `astream` entry points.
pub type AIMessageChunkStream = Pin<Box<dyn Stream<Item = Result<AIMessageChunk>> + Send>>;
/// One streamed fragment of a chat completion.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatChunk {
    /// Text carried by this fragment (empty on the terminating chunk built
    /// by `ChatChunk::final_chunk`).
    pub content: String,
    /// `true` only on the terminating chunk of a stream.
    pub is_final: bool,
    /// Token-usage accounting; the constructors only populate this on the
    /// final chunk. Omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage_metadata: Option<UsageMetadata>,
    /// Provider-reported finish reason; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finish_reason: Option<String>,
}
57
58impl ChatChunk {
59 pub fn new(content: impl Into<String>) -> Self {
61 Self {
62 content: content.into(),
63 is_final: false,
64 usage_metadata: None,
65 finish_reason: None,
66 }
67 }
68
69 pub fn final_chunk(
71 usage_metadata: Option<UsageMetadata>,
72 finish_reason: Option<String>,
73 ) -> Self {
74 Self {
75 content: String::new(),
76 is_final: true,
77 usage_metadata,
78 finish_reason,
79 }
80 }
81
82 pub fn with_usage_metadata(mut self, usage: UsageMetadata) -> Self {
84 self.usage_metadata = Some(usage);
85 self
86 }
87
88 pub fn with_finish_reason(mut self, reason: impl Into<String>) -> Self {
90 self.finish_reason = Some(reason.into());
91 self
92 }
93}
94
/// How the model is allowed or required to pick tools.
///
/// `untagged`: the `String` variant serializes as a bare JSON string, the
/// `Structured` variant as an object.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum ToolChoice {
    /// Bare mode string — see the `auto`/`any`/`none` constructors below.
    String(String),
    /// Object form selecting a specific tool.
    Structured {
        /// Serialized under the JSON key "type".
        #[serde(rename = "type")]
        choice_type: String,
        /// Target tool name; omitted from JSON when `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        name: Option<String>,
    },
}
113
114impl ToolChoice {
115 pub fn auto() -> Self {
117 ToolChoice::String("auto".to_string())
118 }
119
120 pub fn any() -> Self {
122 ToolChoice::String("any".to_string())
123 }
124
125 pub fn none() -> Self {
127 ToolChoice::String("none".to_string())
128 }
129
130 pub fn tool(name: impl Into<String>) -> Self {
132 ToolChoice::Structured {
133 choice_type: "tool".to_string(),
134 name: Some(name.into()),
135 }
136 }
137}
138
/// Policy controlling whether streaming should be disabled for a model call.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum DisableStreaming {
    /// Unconditionally disable (`true`) or allow (`false`) streaming.
    Bool(bool),
    /// Disable streaming only for calls that have tools bound.
    ToolCalling,
}
150
151impl Default for DisableStreaming {
152 fn default() -> Self {
153 DisableStreaming::Bool(false)
154 }
155}
156
157impl DisableStreaming {
158 pub fn should_disable(&self, has_tools: bool) -> bool {
164 match self {
165 DisableStreaming::Bool(b) => *b,
166 DisableStreaming::ToolCalling => has_tools,
167 }
168 }
169}
170
171impl From<bool> for DisableStreaming {
172 fn from(b: bool) -> Self {
173 DisableStreaming::Bool(b)
174 }
175}
176
/// Configuration for chat models: the shared language-model settings plus
/// chat-specific options.
#[derive(Clone, Default)]
pub struct ChatModelConfig {
    /// Common language-model settings (cache, verbose, tags, metadata).
    pub base: LanguageModelConfig,

    /// Optional rate limiter acquired before each streaming call.
    pub rate_limiter: Option<Arc<dyn BaseRateLimiter>>,

    /// Streaming-disable policy; defaults to `Bool(false)` (allowed).
    pub disable_streaming: DisableStreaming,

    /// Optional output version tag (e.g. "v1").
    pub output_version: Option<String>,

    /// Optional model profile; see [`ModelProfile`].
    pub profile: Option<ModelProfile>,
}
206
207impl std::fmt::Debug for ChatModelConfig {
208 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
209 f.debug_struct("ChatModelConfig")
210 .field("base", &self.base)
211 .field(
212 "rate_limiter",
213 &self.rate_limiter.as_ref().map(|_| "<rate_limiter>"),
214 )
215 .field("disable_streaming", &self.disable_streaming)
216 .field("output_version", &self.output_version)
217 .field("profile", &self.profile)
218 .finish()
219 }
220}
221
222impl ChatModelConfig {
223 pub fn new() -> Self {
225 Self::default()
226 }
227
228 pub fn with_rate_limiter(mut self, rate_limiter: Arc<dyn BaseRateLimiter>) -> Self {
230 self.rate_limiter = Some(rate_limiter);
231 self
232 }
233
234 pub fn with_disable_streaming(mut self, disable: impl Into<DisableStreaming>) -> Self {
236 self.disable_streaming = disable.into();
237 self
238 }
239
240 pub fn with_output_version(mut self, version: impl Into<String>) -> Self {
242 self.output_version = Some(version.into());
243 self
244 }
245
246 pub fn with_profile(mut self, profile: ModelProfile) -> Self {
248 self.profile = Some(profile);
249 self
250 }
251
252 pub fn with_cache(mut self, cache: bool) -> Self {
254 self.base.cache = Some(cache);
255 self
256 }
257
258 pub fn with_verbose(mut self, verbose: bool) -> Self {
260 self.base.verbose = verbose;
261 self
262 }
263
264 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
266 self.base.tags = Some(tags);
267 self
268 }
269
270 pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
272 self.base.metadata = Some(metadata);
273 self
274 }
275}
276
/// Core chat-model trait: message list in, `ChatResult` out, with optional
/// streaming, tool hooks, and cache-key helpers.
///
/// Implementors must supply `chat_config` and `_generate`; every other
/// method has a default built on those (async and streaming variants fall
/// back to the sync/non-streaming paths).
#[async_trait]
pub trait BaseChatModel: BaseLanguageModel {
    /// Chat-specific configuration for this model.
    fn chat_config(&self) -> &ChatModelConfig;

    /// Model profile from the config, if one was set.
    fn profile(&self) -> Option<&ModelProfile> {
        self.chat_config().profile.as_ref()
    }

    /// Core generation: produce a `ChatResult` for one message list.
    ///
    /// `stop` is an optional list of stop sequences; `run_manager` carries
    /// the callbacks for this run.
    async fn _generate(
        &self,
        messages: Vec<BaseMessage>,
        stop: Option<Vec<String>>,
        run_manager: Option<&CallbackManagerForLLMRun>,
    ) -> Result<ChatResult>;

    /// Async generation; defaults to delegating to the sync `_generate`.
    ///
    /// NOTE(review): this default drops the run manager — `_generate` is
    /// called with `None`, so callbacks are not forwarded.
    async fn _agenerate(
        &self,
        messages: Vec<BaseMessage>,
        stop: Option<Vec<String>>,
        _run_manager: Option<&AsyncCallbackManagerForLLMRun>,
    ) -> Result<ChatResult> {
        self._generate(messages, stop, None).await
    }

    /// Streaming generation; the default is unimplemented.
    ///
    /// Implementors overriding this should also override `has_stream_impl`
    /// to return `true` so `_should_stream` can detect it.
    fn _stream(
        &self,
        _messages: Vec<BaseMessage>,
        _stop: Option<Vec<String>>,
        _run_manager: Option<&CallbackManagerForLLMRun>,
    ) -> Result<ChatGenerationStream> {
        Err(Error::NotImplemented("Streaming not implemented".into()))
    }

    /// Async streaming; defaults to the sync `_stream`.
    ///
    /// NOTE(review): like `_agenerate`, the run manager is not forwarded.
    async fn _astream(
        &self,
        messages: Vec<BaseMessage>,
        stop: Option<Vec<String>>,
        _run_manager: Option<&AsyncCallbackManagerForLLMRun>,
    ) -> Result<ChatGenerationStream> {
        self._stream(messages, stop, None)
    }

    /// Returns the first generation's message as an `AIMessage`.
    ///
    /// Non-AI messages are converted by copying their text content; errors
    /// only when `result` holds no generations at all.
    fn get_first_message(&self, result: &ChatResult) -> Result<AIMessage> {
        if result.generations.is_empty() {
            return Err(Error::Other("No generations returned".into()));
        }

        match result.generations[0].message.clone() {
            BaseMessage::AI(message) => Ok(message),
            other => Ok(AIMessage::new(other.content())),
        }
    }

    /// Merges per-batch `llm_output` maps into one; the default discards
    /// them and returns an empty map.
    fn _combine_llm_outputs(
        &self,
        _llm_outputs: &[Option<HashMap<String, Value>>],
    ) -> HashMap<String, Value> {
        HashMap::new()
    }

    /// Rebuilds `ChatGeneration`s from cached plain `Generation`s by
    /// wrapping each cached text in a fresh `AIMessage`, preserving any
    /// `generation_info`.
    fn _convert_cached_generations(&self, cache_val: Vec<Generation>) -> Vec<ChatGeneration> {
        cache_val
            .into_iter()
            .map(|cached_gen| {
                let message = AIMessage::new(&cached_gen.text);
                match cached_gen.generation_info {
                    Some(info) => ChatGeneration::with_info(message.into(), info),
                    None => ChatGeneration::new(message.into()),
                }
            })
            .collect()
    }

    /// Parameters identifying this invocation: the identifying params plus
    /// the stop list; `kwargs` entries override on key collision.
    fn _get_invocation_params(
        &self,
        stop: Option<&[String]>,
        kwargs: Option<&HashMap<String, Value>>,
    ) -> HashMap<String, Value> {
        let mut params = self.get_identifying_params();
        if let Some(stop) = stop {
            params.insert(
                "stop".to_string(),
                Value::Array(stop.iter().map(|s| Value::String(s.clone())).collect()),
            );
        }
        if let Some(kw) = kwargs {
            params.extend(kw.clone());
        }
        params
    }

    /// Deterministic string key for caching, built from the invocation
    /// params sorted by key.
    ///
    /// NOTE(review): relies on `Debug` formatting, which is not guaranteed
    /// stable across Rust releases — fine for an in-process cache key,
    /// risky for a persisted one.
    fn _get_llm_string(
        &self,
        stop: Option<&[String]>,
        kwargs: Option<&HashMap<String, Value>>,
    ) -> String {
        let params = self._get_invocation_params(stop, kwargs);

        let mut sorted_items: Vec<_> = params.iter().collect();
        sorted_items.sort_by_key(|(k, _)| *k);

        format!("{:?}", sorted_items)
    }

    /// Whether this type overrides `_stream`; override to return `true`
    /// alongside `_stream` itself.
    fn has_stream_impl(&self) -> bool {
        false
    }

    /// Whether this type overrides `_astream`; see `has_stream_impl`.
    fn has_astream_impl(&self) -> bool {
        false
    }

    /// Tri-state model-level streaming preference (`None` = no preference).
    fn has_streaming_field(&self) -> Option<bool> {
        None
    }

    /// Decides whether a call should take the streaming path.
    ///
    /// Precedence: no usable stream implementation → `false`; config-level
    /// disable (possibly tool-dependent) → `false`; explicit `stream_kwarg`
    /// → as given; model-level streaming field → as given; otherwise `true`.
    fn _should_stream(
        &self,
        async_api: bool,
        has_tools: bool,
        stream_kwarg: Option<bool>,
        run_manager: Option<&[Arc<dyn BaseCallbackHandler>]>,
    ) -> bool {
        let sync_not_implemented = !self.has_stream_impl();
        let async_not_implemented = !self.has_astream_impl();

        // The sync path needs `_stream`; the async path can fall back to
        // `_stream` via the default `_astream`, so it needs at least one.
        if !async_api && sync_not_implemented {
            return false;
        }
        if async_api && async_not_implemented && sync_not_implemented {
            return false;
        }

        if self
            .chat_config()
            .disable_streaming
            .should_disable(has_tools)
        {
            return false;
        }

        if let Some(stream) = stream_kwarg {
            return stream;
        }

        if let Some(streaming) = self.has_streaming_field() {
            return streaming;
        }

        // NOTE(review): this branch is redundant — the fall-through below
        // also returns true regardless of handlers.
        if let Some(handlers) = run_manager {
            if !handlers.is_empty() {
                return true;
            }
        }

        true
    }

    /// Generates a result for each message list, sequentially.
    ///
    /// NOTE(review): `_callbacks` is currently ignored, batches are not run
    /// concurrently, and per-batch `llm_output` is not combined via
    /// `_combine_llm_outputs`.
    async fn generate(
        &self,
        messages: Vec<Vec<BaseMessage>>,
        stop: Option<Vec<String>>,
        _callbacks: Option<Callbacks>,
    ) -> Result<LLMResult> {
        let mut all_generations: Vec<Vec<GenerationType>> = Vec::new();

        for message_list in messages {
            let result = self._generate(message_list, stop.clone(), None).await?;
            all_generations.push(result.generations.into_iter().map(|e| e.into()).collect());
        }

        Ok(LLMResult::new(all_generations))
    }

    /// Async counterpart of `generate`; same caveats apply.
    async fn agenerate(
        &self,
        messages: Vec<Vec<BaseMessage>>,
        stop: Option<Vec<String>>,
        _callbacks: Option<Callbacks>,
    ) -> Result<LLMResult> {
        let mut all_generations: Vec<Vec<GenerationType>> = Vec::new();

        for message_list in messages {
            let result = self._agenerate(message_list, stop.clone(), None).await?;
            all_generations.push(result.generations.into_iter().map(|e| e.into()).collect());
        }

        Ok(LLMResult::new(all_generations))
    }

    /// Async convenience: one message list in, the first generated message
    /// out. Errors when nothing is generated or the generation is not a
    /// chat generation.
    async fn _call_async(
        &self,
        messages: Vec<BaseMessage>,
        stop: Option<Vec<String>>,
        callbacks: Option<Callbacks>,
    ) -> Result<BaseMessage> {
        let result = self.agenerate(vec![messages], stop, callbacks).await?;

        if result.generations.is_empty() || result.generations[0].is_empty() {
            return Err(Error::Other("No generations returned".into()));
        }

        match &result.generations[0][0] {
            GenerationType::ChatGeneration(chat_gen) => Ok(chat_gen.message.clone()),
            _ => Err(Error::Other("Unexpected generation type".into())),
        }
    }

    /// Generation entry point for tool-enabled calls.
    ///
    /// NOTE(review): this default ignores `_tools` and `_tool_choice`
    /// entirely and just delegates to `_generate`; providers must override
    /// it to actually pass tool definitions to the backend. Unlike
    /// `get_first_message`, a non-AI first message is an error here.
    async fn generate_with_tools(
        &self,
        messages: Vec<BaseMessage>,
        _tools: &[ToolDefinition],
        _tool_choice: Option<&ToolChoice>,
        stop: Option<Vec<String>>,
    ) -> Result<AIMessage> {
        let result = self._generate(messages, stop, None).await?;

        if result.generations.is_empty() {
            return Err(Error::Other("No generations returned".into()));
        }

        match result.generations[0].message.clone() {
            BaseMessage::AI(message) => Ok(message),
            _ => Err(Error::Other("Unexpected message type".into())),
        }
    }

    /// Normalizes any supported input form into a message list.
    fn convert_input(&self, input: LanguageModelInput) -> Result<Vec<BaseMessage>> {
        Ok(input.to_messages())
    }

    /// Single-shot call: convert input, generate, return the first AI
    /// message. Errors when nothing is generated or the first message is
    /// not an AI message.
    async fn invoke(&self, input: LanguageModelInput) -> Result<AIMessage> {
        let messages = self.convert_input(input)?;
        let result = self._generate(messages, None, None).await?;

        if result.generations.is_empty() {
            return Err(Error::Other("No generations returned".into()));
        }

        match result.generations[0].message.clone() {
            BaseMessage::AI(message) => Ok(message),
            _ => Err(Error::Other("Unexpected message type".into())),
        }
    }

    /// Async counterpart of `invoke`, routed through `_agenerate`.
    async fn ainvoke(&self, input: LanguageModelInput) -> Result<AIMessage> {
        let messages = self.convert_input(input)?;
        let result = self._agenerate(messages, None, None).await?;

        if result.generations.is_empty() {
            return Err(Error::Other("No generations returned".into()));
        }

        match result.generations[0].message.clone() {
            BaseMessage::AI(message) => Ok(message),
            _ => Err(Error::Other("Unexpected message type".into())),
        }
    }

    /// Attaches tools to the model; the default is unimplemented and must
    /// be overridden by providers that support tool calling.
    fn bind_tools(
        &self,
        _tools: &[Arc<dyn BaseTool>],
        _tool_choice: Option<ToolChoice>,
    ) -> Result<()> {
        Err(Error::NotImplemented(
            "bind_tools is not implemented for this model".into(),
        ))
    }

    /// Extracts the wire-format definition of each tool.
    fn get_tool_definitions(&self, tools: &[Arc<dyn BaseTool>]) -> Vec<ToolDefinition> {
        tools.iter().map(|t| t.definition()).collect()
    }

    /// Public streaming API returning `AIMessageChunk`s.
    ///
    /// Falls back to a one-chunk stream wrapping a non-streaming
    /// `_generate` when `_should_stream` reports streaming unavailable.
    /// When streaming, each generation chunk is converted to an
    /// `AIMessageChunk`, and a final empty chunk tagged
    /// `ChunkPosition::Last` is appended after at least one chunk was
    /// yielded.
    async fn stream(
        &self,
        input: LanguageModelInput,
        stop: Option<Vec<String>>,
    ) -> Result<AIMessageChunkStream> {
        let messages = self.convert_input(input)?;
        // NOTE(review): tools bound elsewhere are not visible here, so the
        // tool-dependent disable policy can never trigger from this path.
        let has_tools = false;

        if !self._should_stream(false, has_tools, Some(true), None) {
            let result = self._generate(messages, stop, None).await?;
            let message = self.get_first_message(&result)?;
            let chunk = AIMessageChunk::new(message.content());
            return Ok(Box::pin(futures::stream::once(async move { Ok(chunk) })));
        }

        if let Some(ref rate_limiter) = self.chat_config().rate_limiter {
            // NOTE(review): sync `acquire` inside an async fn — if it
            // blocks, it stalls the executor. The async path below uses
            // `aacquire(...).await`; confirm whether this should too.
            rate_limiter.acquire(true);
        }

        let generation_stream = self._stream(messages, stop, None)?;

        let chunk_stream = async_stream::stream! {
            use futures::StreamExt;

            let mut pinned_stream = generation_stream;
            let mut yielded = false;

            while let Some(result) = pinned_stream.next().await {
                match result {
                    Ok(generation_chunk) => {
                        // Both arms copy the text content; the match keeps
                        // the AI case explicit.
                        let ai_chunk = match generation_chunk.message {
                            BaseMessage::AI(ai_msg) => AIMessageChunk::new(ai_msg.content()),
                            other => AIMessageChunk::new(other.content()),
                        };
                        yielded = true;
                        yield Ok(ai_chunk);
                    }
                    Err(e) => {
                        // Surface the error and end the stream immediately.
                        yield Err(e);
                        return;
                    }
                }
            }

            if yielded {
                // Terminator chunk marking the end of a successful stream.
                let mut final_chunk = AIMessageChunk::new("");
                final_chunk.set_chunk_position(Some(ChunkPosition::Last));
                yield Ok(final_chunk);
            }
        };

        Ok(Box::pin(chunk_stream))
    }

    /// Async counterpart of `stream`, routed through `_astream` and the
    /// async rate-limiter acquire. Same fallback and terminator behavior.
    async fn astream(
        &self,
        input: LanguageModelInput,
        stop: Option<Vec<String>>,
    ) -> Result<AIMessageChunkStream> {
        let messages = self.convert_input(input)?;
        // NOTE(review): same limitation as `stream` — tools are not visible.
        let has_tools = false;

        if !self._should_stream(true, has_tools, Some(true), None) {
            let result = self._agenerate(messages, stop, None).await?;
            let message = self.get_first_message(&result)?;
            let chunk = AIMessageChunk::new(message.content());
            return Ok(Box::pin(futures::stream::once(async move { Ok(chunk) })));
        }

        if let Some(ref rate_limiter) = self.chat_config().rate_limiter {
            rate_limiter.aacquire(true).await;
        }

        let generation_stream = self._astream(messages, stop, None).await?;

        let chunk_stream = async_stream::stream! {
            use futures::StreamExt;

            let mut pinned_stream = generation_stream;
            let mut yielded = false;

            while let Some(result) = pinned_stream.next().await {
                match result {
                    Ok(generation_chunk) => {
                        let ai_chunk = match generation_chunk.message {
                            BaseMessage::AI(ai_msg) => AIMessageChunk::new(ai_msg.content()),
                            other => AIMessageChunk::new(other.content()),
                        };
                        yielded = true;
                        yield Ok(ai_chunk);
                    }
                    Err(e) => {
                        yield Err(e);
                        return;
                    }
                }
            }

            if yielded {
                // Terminator chunk marking the end of a successful stream.
                let mut final_chunk = AIMessageChunk::new("");
                final_chunk.set_chunk_position(Some(ChunkPosition::Last));
                yield Ok(final_chunk);
            }
        };

        Ok(Box::pin(chunk_stream))
    }

    /// Streams raw `ChatGenerationChunk`s, falling back to a single-chunk
    /// stream wrapping `_generate` when streaming is unavailable. Unlike
    /// `stream`/`astream`, the run manager IS forwarded here.
    async fn stream_generations(
        &self,
        messages: Vec<BaseMessage>,
        stop: Option<Vec<String>>,
        run_manager: Option<&CallbackManagerForLLMRun>,
    ) -> Result<ChatGenerationStream> {
        let has_tools = false;

        if !self._should_stream(false, has_tools, None, None) {
            let result = self._generate(messages, stop, run_manager).await?;
            if result.generations.is_empty() {
                return Err(Error::Other("No generations returned".into()));
            }

            let message = result.generations[0].message.clone();
            let chunk = ChatGenerationChunk::new(message);
            return Ok(Box::pin(futures::stream::once(async move { Ok(chunk) })));
        }

        self._stream(messages, stop, run_manager)
    }

    /// LangSmith params with the model type pinned to "chat".
    fn get_chat_ls_params(&self, stop: Option<&[String]>) -> LangSmithParams {
        let mut params = self.get_ls_params(stop);
        params.ls_model_type = Some("chat".to_string());
        params
    }

    /// Serializable snapshot: identifying params plus a "_type" tag.
    fn to_dict(&self) -> HashMap<String, Value> {
        let mut result = self.get_identifying_params();
        result.insert(
            "_type".to_string(),
            Value::String(self.llm_type().to_string()),
        );
        result
    }

    /// Structured-output binding; the default is unimplemented.
    fn with_structured_output(&self, _schema: Value, _include_raw: bool) -> Result<()> {
        Err(Error::NotImplemented(
            "with_structured_output is not implemented for this model".into(),
        ))
    }

    /// Minimal identifying params: the LLM type and model name.
    fn get_identifying_params(&self) -> HashMap<String, Value> {
        let mut params = HashMap::new();
        params.insert(
            "_type".to_string(),
            Value::String(self.llm_type().to_string()),
        );
        params.insert(
            "model".to_string(),
            Value::String(self.model_name().to_string()),
        );
        params
    }
}
974
975#[async_trait]
980pub trait SimpleChatModel: BaseChatModel {
981 async fn _call(
985 &self,
986 messages: Vec<BaseMessage>,
987 stop: Option<Vec<String>>,
988 run_manager: Option<&CallbackManagerForLLMRun>,
989 ) -> Result<String>;
990}
991
992#[async_trait]
993impl<T: SimpleChatModel> BaseChatModel for T {
994 fn chat_config(&self) -> &ChatModelConfig {
995 <T as BaseChatModel>::chat_config(self)
996 }
997
998 async fn _generate(
999 &self,
1000 messages: Vec<BaseMessage>,
1001 stop: Option<Vec<String>>,
1002 run_manager: Option<&CallbackManagerForLLMRun>,
1003 ) -> Result<ChatResult> {
1004 let output_str = self._call(messages, stop, run_manager).await?;
1005 let message = AIMessage::new(output_str);
1006 let generation = ChatGeneration::new(message.into());
1007 Ok(ChatResult::new(vec![generation]))
1008 }
1009}
1010
1011pub fn generate_from_stream<I>(mut stream: I) -> Result<ChatResult>
1029where
1030 I: Iterator<Item = ChatGenerationChunk>,
1031{
1032 let first = stream.next();
1033 if first.is_none() {
1034 return Err(Error::Other("No generations found in stream.".into()));
1035 }
1036
1037 let mut generation = first.unwrap();
1038
1039 for chunk in stream {
1041 generation = generation + chunk;
1042 }
1043
1044 let chat_generation: ChatGeneration = generation.into();
1046 Ok(ChatResult::new(vec![chat_generation]))
1047}
1048
1049pub async fn agenerate_from_stream(
1067 stream: impl futures::Stream<Item = Result<ChatGenerationChunk>> + Unpin,
1068) -> Result<ChatResult> {
1069 use futures::StreamExt;
1070
1071 let chunks: Vec<ChatGenerationChunk> = stream
1072 .filter_map(|result| async { result.ok() })
1073 .collect()
1074 .await;
1075
1076 if chunks.is_empty() {
1077 return Err(Error::Other("No generations found in stream.".into()));
1078 }
1079
1080 generate_from_stream(chunks.into_iter())
1081}
1082
1083pub async fn collect_and_merge_stream(
1096 mut stream: impl futures::StreamExt<Item = Result<ChatGenerationChunk>> + Unpin,
1097) -> Result<Option<ChatGenerationChunk>> {
1098 let mut chunks = Vec::new();
1099 while let Some(chunk_result) = stream.next().await {
1100 chunks.push(chunk_result?);
1101 }
1102
1103 if chunks.is_empty() {
1104 return Ok(None);
1105 }
1106
1107 Ok(crate::outputs::merge_chat_generation_chunks(chunks))
1108}
1109
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chat_model_config_builder() {
        let cfg = ChatModelConfig::new()
            .with_cache(true)
            .with_verbose(true)
            .with_disable_streaming(true)
            .with_output_version("v1");

        assert_eq!(cfg.base.cache, Some(true));
        assert!(cfg.base.verbose);
        assert_eq!(cfg.disable_streaming, DisableStreaming::Bool(true));
        assert_eq!(cfg.output_version, Some(String::from("v1")));
    }

    #[test]
    fn test_tool_choice_auto() {
        assert_eq!(ToolChoice::auto(), ToolChoice::String("auto".to_string()));
    }

    #[test]
    fn test_tool_choice_any() {
        assert_eq!(ToolChoice::any(), ToolChoice::String("any".to_string()));
    }

    #[test]
    fn test_tool_choice_none() {
        assert_eq!(ToolChoice::none(), ToolChoice::String("none".to_string()));
    }

    #[test]
    fn test_tool_choice_tool() {
        let expected = ToolChoice::Structured {
            choice_type: "tool".to_string(),
            name: Some("my_tool".to_string()),
        };
        assert_eq!(ToolChoice::tool("my_tool"), expected);
    }

    #[test]
    fn test_tool_choice_serialization() {
        // Plain variants serialize untagged, as a bare JSON string.
        let json = serde_json::to_string(&ToolChoice::auto()).unwrap();
        assert_eq!(json, "\"auto\"");

        // Structured variants carry both the type tag and the tool name.
        let json = serde_json::to_string(&ToolChoice::tool("my_tool")).unwrap();
        assert!(json.contains("my_tool"));
        assert!(json.contains("tool"));
    }

    #[test]
    fn test_disable_streaming() {
        // `Bool` ignores the tool flag entirely.
        for has_tools in [false, true] {
            assert!(!DisableStreaming::Bool(false).should_disable(has_tools));
            assert!(DisableStreaming::Bool(true).should_disable(has_tools));
        }

        // `ToolCalling` disables streaming only when tools are bound.
        assert!(DisableStreaming::ToolCalling.should_disable(true));
        assert!(!DisableStreaming::ToolCalling.should_disable(false));
    }

    #[test]
    fn test_generate_from_stream() {
        let parts = vec![
            ChatGenerationChunk::new(AIMessage::new("Hello, ").into()),
            ChatGenerationChunk::new(AIMessage::new("world!").into()),
        ];

        let merged = generate_from_stream(parts.into_iter()).unwrap();
        assert_eq!(merged.generations.len(), 1);
        assert_eq!(merged.generations[0].message.content(), "Hello, world!");
    }

    #[test]
    fn test_generate_from_stream_empty() {
        let empty = Vec::<ChatGenerationChunk>::new();
        assert!(generate_from_stream(empty.into_iter()).is_err());
    }

    #[tokio::test]
    async fn test_agenerate_from_stream() {
        let stream = futures::stream::iter(vec![
            Ok(ChatGenerationChunk::new(AIMessage::new("Hello, ").into())),
            Ok(ChatGenerationChunk::new(AIMessage::new("world!").into())),
        ]);

        let merged = agenerate_from_stream(stream).await.unwrap();
        assert_eq!(merged.generations.len(), 1);
        assert_eq!(merged.generations[0].message.content(), "Hello, world!");
    }

    #[tokio::test]
    async fn test_collect_and_merge_stream() {
        let stream = futures::stream::iter(vec![
            Ok(ChatGenerationChunk::new(AIMessage::new("a").into())),
            Ok(ChatGenerationChunk::new(AIMessage::new("b").into())),
            Ok(ChatGenerationChunk::new(AIMessage::new("c").into())),
        ]);

        let merged = collect_and_merge_stream(stream)
            .await
            .unwrap()
            .expect("stream had chunks");
        assert_eq!(merged.text, "abc");
    }

    #[tokio::test]
    async fn test_collect_and_merge_stream_empty() {
        let stream = futures::stream::iter(Vec::<Result<ChatGenerationChunk>>::new());
        assert!(collect_and_merge_stream(stream).await.unwrap().is_none());
    }
}