neuromance_common/
client.rs

1use std::collections::HashMap;
2use std::fmt;
3use std::str::FromStr;
4use std::sync::Arc;
5use std::time::Duration;
6
7use chrono::{DateTime, Utc};
8use secrecy::SecretString;
9use serde::{Deserialize, Serialize};
10
11use crate::chat::Message;
12use crate::tools::Tool;
13
14/// Controls how the model selects which tool to call, if any.
15///
16/// This enum provides fine-grained control over tool selection behavior,
17/// from fully automatic to forcing a specific function call.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19#[non_exhaustive]
20pub enum ToolChoice {
21    /// Let the model automatically decide whether to call a tool and which one.
22    ///
23    /// This is the default behavior for most use cases.
24    #[serde(rename = "auto")]
25    Auto,
26    /// Disable tool calling for this request.
27    ///
28    /// The model will not call any tools and will respond directly.
29    #[serde(rename = "none")]
30    None,
31    /// Require the model to call at least one tool.
32    ///
33    /// The model must call a tool rather than responding directly.
34    #[serde(rename = "required")]
35    Required,
36    /// Force the model to call a specific function by name.
37    ///
38    /// # Example
39    ///
40    /// ```
41    /// use neuromance_common::ToolChoice;
42    ///
43    /// let choice = ToolChoice::Function {
44    ///     name: "get_weather".to_string(),
45    /// };
46    /// ```
47    Function {
48        /// The name of the function to call
49        name: String,
50    },
51}
52
53impl fmt::Display for ToolChoice {
54    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
55        match self {
56            ToolChoice::Auto => write!(f, "auto"),
57            ToolChoice::None => write!(f, "none"),
58            ToolChoice::Required => write!(f, "required"),
59            ToolChoice::Function { name } => write!(f, "{}", name),
60        }
61    }
62}
63
64impl From<ToolChoice> for serde_json::Value {
65    fn from(tool_choice: ToolChoice) -> Self {
66        match tool_choice {
67            ToolChoice::Auto => serde_json::Value::String("auto".to_string()),
68            ToolChoice::None => serde_json::Value::String("none".to_string()),
69            ToolChoice::Required => serde_json::Value::String("required".to_string()),
70            ToolChoice::Function { name } => {
71                serde_json::json!({
72                    "type": "function",
73                    "function": {
74                        "name": name
75                    }
76                })
77            }
78        }
79    }
80}
81
82/// Indicates why the model stopped generating tokens.
83///
84/// This enum provides information about whether generation completed naturally,
85/// was truncated, or was interrupted for another reason.
86#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Copy)]
87#[non_exhaustive]
88pub enum FinishReason {
89    /// Generation completed naturally at a stop sequence or end of response.
90    ///
91    /// This is the most common finish reason for successful completions.
92    #[serde(rename = "stop")]
93    Stop,
94    /// Generation was truncated because the maximum token limit was reached.
95    ///
96    /// Consider increasing `max_tokens` if the response appears incomplete.
97    #[serde(rename = "length")]
98    Length,
99    /// Generation stopped because the model requested tool calls.
100    ///
101    /// The response contains tool calls that should be executed.
102    #[serde(rename = "tool_calls")]
103    ToolCalls,
104    /// Generation was stopped by the content filter.
105    ///
106    /// The response may have been blocked due to safety policies.
107    #[serde(rename = "content_filter")]
108    ContentFilter,
109    /// Generation failed due to a model error.
110    ///
111    /// This typically indicates an internal error in the model.
112    #[serde(rename = "model_error")]
113    ModelError,
114}
115
116impl fmt::Display for FinishReason {
117    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
118        match self {
119            FinishReason::Stop => write!(f, "stop"),
120            FinishReason::Length => write!(f, "length"),
121            FinishReason::ToolCalls => write!(f, "tool_calls"),
122            FinishReason::ContentFilter => write!(f, "content_filter"),
123            FinishReason::ModelError => write!(f, "model_error"),
124        }
125    }
126}
127
128impl FromStr for FinishReason {
129    type Err = anyhow::Error;
130
131    fn from_str(s: &str) -> Result<Self, Self::Err> {
132        match s {
133            "stop" => Ok(FinishReason::Stop),
134            "length" => Ok(FinishReason::Length),
135            "tool_calls" => Ok(FinishReason::ToolCalls),
136            "content_filter" => Ok(FinishReason::ContentFilter),
137            "model_error" => Ok(FinishReason::ModelError),
138            _ => anyhow::bail!("Unknown finish reason: {}", s),
139        }
140    }
141}
142
/// Settings for retrying failed requests with exponential backoff.
///
/// The delay before each retry starts at `initial_delay` and is scaled by
/// `backoff_multiplier`, with `max_delay` acting as the cap. Optional jitter
/// adds randomness to delays so many clients do not retry in lockstep.
///
/// # Examples
///
/// ```
/// use std::time::Duration;
/// use neuromance_common::client::RetryConfig;
///
/// // Conservative retry policy
/// let config = RetryConfig {
///     max_retries: 5,
///     initial_delay: Duration::from_millis(500),
///     max_delay: Duration::from_secs(60),
///     backoff_multiplier: 2.0,
///     jitter: true,
/// };
/// ```
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// How many retries to attempt before giving up.
    pub max_retries: usize,
    /// Wait time before the first retry.
    pub initial_delay: Duration,
    /// Upper bound on the wait between retries (caps exponential growth).
    pub max_delay: Duration,
    /// Growth factor applied to the delay after each attempt (2.0 doubles it).
    pub backoff_multiplier: f64,
    /// Whether retry delays get random jitter to avoid a thundering herd.
    pub jitter: bool,
}

impl Default for RetryConfig {
    /// Three retries, starting at one second and doubling up to a
    /// 30-second cap, with jitter enabled.
    fn default() -> Self {
        let initial_delay = Duration::from_secs(1);
        let max_delay = Duration::from_secs(30);
        Self {
            max_retries: 3,
            initial_delay,
            max_delay,
            backoff_multiplier: 2.0,
            jitter: true,
        }
    }
}
188
189/// Token usage statistics for a completion request.
190///
191/// Tracks the number of tokens consumed by the prompt and completion,
192/// along with optional cost information and detailed breakdowns.
193///
194/// # Note
195///
196/// Different providers may count tokens differently. The `total_tokens`
197/// should always equal `prompt_tokens + completion_tokens`.
198#[derive(Debug, Clone, Serialize, Deserialize)]
199pub struct Usage {
200    /// Number of tokens in the input prompt.
201    #[serde(alias = "input_tokens")]
202    pub prompt_tokens: u32,
203    /// Number of tokens generated in the completion.
204    #[serde(alias = "output_tokens")]
205    pub completion_tokens: u32,
206    /// Total tokens used (prompt + completion).
207    pub total_tokens: u32,
208    /// Estimated cost in USD for this request (if available).
209    pub cost: Option<f64>,
210    /// Detailed breakdown of input token usage.
211    #[serde(skip_serializing_if = "Option::is_none")]
212    pub input_tokens_details: Option<InputTokensDetails>,
213    /// Detailed breakdown of output token usage.
214    #[serde(skip_serializing_if = "Option::is_none")]
215    pub output_tokens_details: Option<OutputTokensDetails>,
216}
217
218/// Detailed breakdown of input token usage.
219///
220/// Provides additional information about how input tokens were processed,
221/// including cache utilization.
222#[derive(Debug, Clone, Serialize, Deserialize)]
223pub struct InputTokensDetails {
224    /// Number of tokens served from cache rather than processed fresh.
225    ///
226    /// Cached tokens are typically cheaper and faster to process.
227    pub cached_tokens: u32,
228}
229
230/// Detailed breakdown of output token usage.
231///
232/// Provides additional information about token usage in the model's response,
233/// including reasoning tokens for models that support chain-of-thought.
234#[derive(Debug, Clone, Serialize, Deserialize)]
235pub struct OutputTokensDetails {
236    /// Number of tokens used for internal reasoning (e.g., chain-of-thought).
237    ///
238    /// Some models generate reasoning tokens that are not part of the final response.
239    pub reasoning_tokens: u32,
240}
241
242/// A request for a chat completion from an LLM.
243///
244/// This struct encapsulates all parameters needed to request a completion,
245/// including conversation history, sampling parameters, and tool configuration.
246///
247/// # Examples
248///
249/// ```
250/// use neuromance_common::{ChatRequest, Message, MessageRole};
251/// use uuid::Uuid;
252///
253/// let message = Message::new(Uuid::new_v4(), MessageRole::User, "Hello!");
254/// let request = ChatRequest::new(vec![message])
255///     .with_model("gpt-4")
256///     .with_temperature(0.7)
257///     .with_max_tokens(1000);
258/// ```
259#[derive(Debug, Clone, Serialize, Deserialize)]
260pub struct ChatRequest {
261    /// The conversation messages to send to the model.
262    pub messages: Arc<[Message]>,
263    /// The model identifier to use for generation.
264    pub model: Option<String>,
265    /// Sampling temperature controlling randomness (0.0 to 2.0).
266    ///
267    /// Lower values make output more focused and deterministic.
268    pub temperature: Option<f32>,
269    /// Maximum number of tokens to generate in the response.
270    pub max_tokens: Option<u32>,
271    /// Nucleus sampling threshold (0.0 to 1.0).
272    ///
273    /// Only tokens comprising the top `top_p` probability mass are considered.
274    pub top_p: Option<f32>,
275    /// Penalty for token frequency in the output (-2.0 to 2.0).
276    ///
277    /// Positive values decrease likelihood of repeating tokens.
278    pub frequency_penalty: Option<f32>,
279    /// Penalty for token presence in the output (-2.0 to 2.0).
280    ///
281    /// Positive values encourage discussing new topics.
282    pub presence_penalty: Option<f32>,
283    /// Sequences that will stop generation when encountered.
284    pub stop: Option<Vec<String>>,
285    /// Tools available for the model to call.
286    pub tools: Option<Vec<Tool>>,
287    /// Strategy for tool selection.
288    pub tool_choice: Option<ToolChoice>,
289    /// Whether to stream the response incrementally.
290    pub stream: bool,
291    /// End-user identifier for tracking and abuse prevention.
292    pub user: Option<String>,
293    /// Whether to enable thinking mode (vendor-specific, e.g., Qwen models).
294    pub enable_thinking: Option<bool>,
295    /// Additional metadata to attach to this request.
296    pub metadata: HashMap<String, serde_json::Value>,
297}
298
299impl fmt::Display for ChatRequest {
300    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
301        match serde_json::to_string(self) {
302            Ok(json) => write!(f, "{}", json),
303            Err(_) => write!(f, "Error serializing ChatRequest to JSON"),
304        }
305    }
306}
307
308impl ChatRequest {
309    /// Creates a new chat request with the given messages.
310    ///
311    /// All optional parameters are set to `None` and must be configured
312    /// using the builder methods.
313    ///
314    /// # Arguments
315    ///
316    /// * `messages` - The conversation history to send to the model
317    ///
318    /// # Examples
319    ///
320    /// ```
321    /// use neuromance_common::{ChatRequest, Message, MessageRole};
322    /// use uuid::Uuid;
323    ///
324    /// let msg = Message::new(Uuid::new_v4(), MessageRole::User, "Hello!");
325    /// let request = ChatRequest::new(vec![msg]);
326    /// ```
327    pub fn new(messages: impl Into<Arc<[Message]>>) -> Self {
328        Self {
329            messages: messages.into(),
330            model: None,
331            temperature: None,
332            max_tokens: None,
333            top_p: None,
334            frequency_penalty: None,
335            presence_penalty: None,
336            stop: None,
337            tools: None,
338            tool_choice: None,
339            stream: false,
340            user: None,
341            enable_thinking: None,
342            metadata: HashMap::new(),
343        }
344    }
345}
346
347impl From<(&Config, Vec<Message>)> for ChatRequest {
348    fn from((config, messages): (&Config, Vec<Message>)) -> Self {
349        Self {
350            messages: messages.into(),
351            model: Some(config.model.clone()),
352            temperature: config.temperature,
353            max_tokens: config.max_tokens,
354            top_p: config.top_p,
355            frequency_penalty: config.frequency_penalty,
356            presence_penalty: config.presence_penalty,
357            stop: config.stop_sequences.clone(),
358            tools: None,
359            tool_choice: None,
360            stream: false,
361            user: None,
362            enable_thinking: None,
363            metadata: HashMap::new(),
364        }
365    }
366}
367
368impl From<(&Config, Arc<[Message]>)> for ChatRequest {
369    fn from((config, messages): (&Config, Arc<[Message]>)) -> Self {
370        Self {
371            messages,
372            model: Some(config.model.clone()),
373            temperature: config.temperature,
374            max_tokens: config.max_tokens,
375            top_p: config.top_p,
376            frequency_penalty: config.frequency_penalty,
377            presence_penalty: config.presence_penalty,
378            stop: config.stop_sequences.clone(),
379            tools: None,
380            tool_choice: None,
381            stream: false,
382            user: None,
383            enable_thinking: None,
384            metadata: HashMap::new(),
385        }
386    }
387}
388
389impl ChatRequest {
390    /// Sets the model to use for this request.
391    ///
392    /// # Arguments
393    ///
394    /// * `model` - The model identifier (e.g., "gpt-4", "claude-3-opus")
395    pub fn with_model(mut self, model: impl Into<String>) -> Self {
396        self.model = Some(model.into());
397        self
398    }
399
400    /// Sets the sampling temperature.
401    ///
402    /// # Arguments
403    ///
404    /// * `temperature` - Value between 0.0 and 2.0
405    ///
406    /// Higher values produce more random output.
407    pub fn with_temperature(mut self, temperature: f32) -> Self {
408        self.temperature = Some(temperature);
409        self
410    }
411
412    /// Sets the maximum number of tokens to generate.
413    ///
414    /// # Arguments
415    ///
416    /// * `max_tokens` - Maximum tokens in the response
417    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
418        self.max_tokens = Some(max_tokens);
419        self
420    }
421
422    /// Sets the nucleus sampling threshold.
423    ///
424    /// # Arguments
425    ///
426    /// * `top_p` - Value between 0.0 and 1.0
427    ///
428    /// Lower values make output more focused.
429    pub fn with_top_p(mut self, top_p: f32) -> Self {
430        self.top_p = Some(top_p);
431        self
432    }
433
434    /// Sets the frequency penalty.
435    ///
436    /// # Arguments
437    ///
438    /// * `frequency_penalty` - Value between -2.0 and 2.0
439    ///
440    /// Positive values discourage token repetition.
441    pub fn with_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
442        self.frequency_penalty = Some(frequency_penalty);
443        self
444    }
445
446    /// Sets the presence penalty.
447    ///
448    /// # Arguments
449    ///
450    /// * `presence_penalty` - Value between -2.0 and 2.0
451    ///
452    /// Positive values encourage discussing new topics.
453    pub fn with_presence_penalty(mut self, presence_penalty: f32) -> Self {
454        self.presence_penalty = Some(presence_penalty);
455        self
456    }
457
458    /// Sets stop sequences that will halt generation.
459    ///
460    /// # Arguments
461    ///
462    /// * `stop_sequences` - An iterable of strings to use as stop sequences
463    pub fn with_stop_sequences(
464        mut self,
465        stop_sequences: impl IntoIterator<Item = impl Into<String>>,
466    ) -> Self {
467        self.stop = Some(stop_sequences.into_iter().map(|s| s.into()).collect());
468        self
469    }
470
471    /// Sets the tools available for the model to call.
472    ///
473    /// # Arguments
474    ///
475    /// * `tools` - Vector of tool definitions
476    pub fn with_tools(mut self, tools: Vec<Tool>) -> Self {
477        self.tools = Some(tools);
478        self
479    }
480
481    /// Sets the tool selection strategy.
482    ///
483    /// # Arguments
484    ///
485    /// * `tool_choice` - Strategy for how the model should select tools
486    pub fn with_tool_choice(mut self, tool_choice: ToolChoice) -> Self {
487        self.tool_choice = Some(tool_choice);
488        self
489    }
490
491    /// Enables or disables streaming for this request.
492    ///
493    /// # Arguments
494    ///
495    /// * `stream` - Whether to stream the response
496    pub fn with_streaming(mut self, stream: bool) -> Self {
497        self.stream = stream;
498        self
499    }
500
501    /// Sets custom metadata for this request.
502    ///
503    /// # Arguments
504    ///
505    /// * `metadata` - Key-value pairs of metadata
506    pub fn with_metadata(mut self, metadata: HashMap<String, serde_json::Value>) -> Self {
507        self.metadata = metadata;
508        self
509    }
510
511    /// Enables or disables thinking mode (vendor-specific).
512    ///
513    /// # Arguments
514    ///
515    /// * `enable_thinking` - Whether to enable thinking mode
516    pub fn with_thinking(mut self, enable_thinking: bool) -> Self {
517        self.enable_thinking = Some(enable_thinking);
518        self
519    }
520
521    /// Validate that this request has at least one message.
522    ///
523    /// # Errors
524    ///
525    /// Returns an error if the messages vector is empty.
526    pub fn validate_has_messages(&self) -> anyhow::Result<()> {
527        if self.messages.is_empty() {
528            anyhow::bail!("Chat request must have at least one message");
529        }
530        Ok(())
531    }
532
533    /// Validates all configuration parameters
534    pub fn validate(&self) -> anyhow::Result<()> {
535        self.validate_has_messages()?;
536
537        if let Some(temp) = self.temperature
538            && !(0.0..=2.0).contains(&temp)
539        {
540            anyhow::bail!("Temperature must be between 0.0 and 2.0, got {}", temp);
541        }
542
543        if let Some(top_p) = self.top_p
544            && !(0.0..=1.0).contains(&top_p)
545        {
546            anyhow::bail!("top_p must be between 0.0 and 1.0, got {}", top_p);
547        }
548
549        if let Some(freq_penalty) = self.frequency_penalty
550            && !(-2.0..=2.0).contains(&freq_penalty)
551        {
552            anyhow::bail!(
553                "frequency_penalty must be between -2.0 and 2.0, got {}",
554                freq_penalty
555            );
556        }
557
558        if let Some(pres_penalty) = self.presence_penalty
559            && !(-2.0..=2.0).contains(&pres_penalty)
560        {
561            anyhow::bail!(
562                "presence_penalty must be between -2.0 and 2.0, got {}",
563                pres_penalty
564            );
565        }
566
567        Ok(())
568    }
569
570    /// Returns whether this request has tools configured.
571    ///
572    /// # Returns
573    ///
574    /// `true` if tools are present and non-empty, `false` otherwise.
575    pub fn has_tools(&self) -> bool {
576        self.tools.as_ref().is_some_and(|t| !t.is_empty())
577    }
578
579    /// Returns whether this request uses streaming.
580    ///
581    /// # Returns
582    ///
583    /// `true` if streaming is enabled, `false` otherwise.
584    pub fn is_streaming(&self) -> bool {
585        self.stream
586    }
587}
588
589/// A response from a chat completion request.
590///
591/// Contains the generated message, usage statistics, and metadata about
592/// how and why generation completed.
593///
594/// # Examples
595///
596/// ```no_run
597/// # use neuromance_common::{ChatResponse, Message, MessageRole};
598/// # use neuromance_common::client::FinishReason;
599/// # use uuid::Uuid;
600/// # use chrono::Utc;
601/// # let message = Message::new(Uuid::new_v4(), MessageRole::Assistant, "Hello!");
602/// # let response = ChatResponse {
603/// #     message: message.clone(),
604/// #     model: "gpt-4".to_string(),
605/// #     usage: None,
606/// #     finish_reason: Some(FinishReason::Stop),
607/// #     created_at: Utc::now(),
608/// #     response_id: Some("resp_123".to_string()),
609/// #     metadata: std::collections::HashMap::new(),
610/// # };
611/// // Check why generation stopped
612/// if response.finish_reason == Some(FinishReason::Length) {
613///     println!("Response was truncated - consider increasing max_tokens");
614/// }
615/// ```
616#[derive(Debug, Clone, Serialize, Deserialize)]
617pub struct ChatResponse {
618    /// The generated message from the model.
619    pub message: Message,
620    /// The identifier of the model that generated this response.
621    pub model: String,
622    /// Token usage statistics for this request.
623    pub usage: Option<Usage>,
624    /// Reason why generation stopped.
625    pub finish_reason: Option<FinishReason>,
626    /// Timestamp when this response was created.
627    pub created_at: DateTime<Utc>,
628    /// Unique identifier for this response from the provider.
629    pub response_id: Option<String>,
630    /// Additional metadata about this response.
631    pub metadata: HashMap<String, serde_json::Value>,
632}
633
634/// A chunk from a streaming chat completion.
635///
636/// Represents an incremental update to a chat response. Multiple chunks
637/// are combined to form the complete response. Typically received from
638/// streaming APIs where the response is delivered incrementally.
639#[derive(Debug, Clone, Serialize, Deserialize)]
640pub struct ChatChunk {
641    /// The model identifier that generated this chunk.
642    pub model: String,
643    /// Incremental content added in this chunk.
644    pub delta_content: Option<String>,
645    /// The role of the message (only present in first chunk).
646    pub delta_role: Option<crate::chat::MessageRole>,
647    /// Tool calls being built incrementally.
648    pub delta_tool_calls: Option<Vec<crate::tools::ToolCall>>,
649    /// Reason why generation stopped (only present in final chunk).
650    pub finish_reason: Option<FinishReason>,
651    /// Token usage statistics (only present in final chunk for some providers).
652    pub usage: Option<Usage>,
653    /// Unique identifier for this response stream.
654    pub response_id: Option<String>,
655    /// Timestamp when this chunk was created.
656    pub created_at: DateTime<Utc>,
657    /// Additional metadata about this chunk.
658    pub metadata: HashMap<String, serde_json::Value>,
659}
660
661impl fmt::Display for ChatResponse {
662    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
663        match serde_json::to_string(self) {
664            Ok(json) => write!(f, "{}", json),
665            Err(_) => write!(f, "Error serializing ChatResponse to JSON"),
666        }
667    }
668}
669
670/// Configuration for an LLM client.
671///
672/// This struct holds both connection details (API keys, URLs) and default
673/// generation parameters that will be applied to all requests unless overridden.
674///
675/// # Security
676///
677/// The `api_key` field uses `SecretString` to prevent accidental logging or
678/// display of sensitive credentials.
679///
680/// # Examples
681///
682/// ```
683/// use neuromance_common::Config;
684///
685/// let config = Config::new("openai", "gpt-4")
686///     .with_api_key("sk-...")
687///     .with_temperature(0.7)
688///     .with_max_tokens(1000);
689/// ```
690#[derive(Debug, Clone, Serialize, Deserialize)]
691pub struct Config {
692    /// The LLM provider name (e.g., "openai", "anthropic").
693    pub provider: String,
694    /// The default model identifier to use.
695    pub model: String,
696    /// Optional custom base URL for API requests.
697    ///
698    /// Override this for self-hosted deployments or custom endpoints.
699    pub base_url: Option<String>,
700    /// API key for authentication (stored securely).
701    ///
702    /// Will not be serialized to prevent accidental exposure.
703    #[serde(skip_serializing, default)]
704    pub api_key: Option<SecretString>,
705    /// Optional organization identifier.
706    pub organization: Option<String>,
707    /// Request timeout in seconds.
708    pub timeout_seconds: Option<u64>,
709    /// Configuration for retry behavior with exponential backoff.
710    #[serde(skip)]
711    pub retry_config: RetryConfig,
712    /// Default sampling temperature (0.0 to 2.0).
713    pub temperature: Option<f32>,
714    /// Default maximum tokens to generate.
715    pub max_tokens: Option<u32>,
716    /// Default nucleus sampling threshold (0.0 to 1.0).
717    pub top_p: Option<f32>,
718    /// Default frequency penalty (-2.0 to 2.0).
719    pub frequency_penalty: Option<f32>,
720    /// Default presence penalty (-2.0 to 2.0).
721    pub presence_penalty: Option<f32>,
722    /// Default stop sequences.
723    pub stop_sequences: Option<Vec<String>>,
724    /// Additional metadata to attach to all requests.
725    pub metadata: HashMap<String, serde_json::Value>,
726}
727
728impl Default for Config {
729    fn default() -> Self {
730        Self {
731            provider: "ollama".to_string(),
732            model: "gpt-oss:20b".to_string(),
733            base_url: None,
734            api_key: None,
735            organization: None,
736            timeout_seconds: None,
737            retry_config: RetryConfig::default(),
738            temperature: None,
739            max_tokens: None,
740            top_p: None,
741            frequency_penalty: None,
742            presence_penalty: None,
743            stop_sequences: None,
744            metadata: HashMap::new(),
745        }
746    }
747}
748
749impl Config {
750    /// Creates a new configuration with the specified provider and model.
751    ///
752    /// All optional fields are initialized to their defaults.
753    ///
754    /// # Arguments
755    ///
756    /// * `provider` - The LLM provider name
757    /// * `model` - The model identifier
758    ///
759    /// # Examples
760    ///
761    /// ```
762    /// use neuromance_common::Config;
763    ///
764    /// let config = Config::new("openai", "gpt-4");
765    /// ```
766    pub fn new(provider: impl Into<String>, model: impl Into<String>) -> Self {
767        Self {
768            provider: provider.into(),
769            model: model.into(),
770            ..Default::default()
771        }
772    }
773
774    /// Sets a custom base URL for API requests.
775    ///
776    /// # Arguments
777    ///
778    /// * `base_url` - The base URL for the API
779    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
780        self.base_url = Some(base_url.into());
781        self
782    }
783
784    /// Sets the API key for authentication.
785    ///
786    /// The key is stored securely using `SecretString`.
787    ///
788    /// # Arguments
789    ///
790    /// * `api_key` - The API key
791    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
792        self.api_key = Some(SecretString::new(api_key.into().into()));
793        self
794    }
795
796    /// Sets the organization identifier.
797    ///
798    /// # Arguments
799    ///
800    /// * `organization` - The organization ID
801    pub fn with_organization(mut self, organization: impl Into<String>) -> Self {
802        self.organization = Some(organization.into());
803        self
804    }
805
806    /// Sets the request timeout.
807    ///
808    /// # Arguments
809    ///
810    /// * `timeout_seconds` - Timeout in seconds
811    pub fn with_timeout(mut self, timeout_seconds: u64) -> Self {
812        self.timeout_seconds = Some(timeout_seconds);
813        self
814    }
815
816    /// Sets the default sampling temperature.
817    ///
818    /// # Arguments
819    ///
820    /// * `temperature` - Value between 0.0 and 2.0
821    pub fn with_temperature(mut self, temperature: f32) -> Self {
822        self.temperature = Some(temperature);
823        self
824    }
825
826    /// Sets the default maximum tokens to generate.
827    ///
828    /// # Arguments
829    ///
830    /// * `max_tokens` - Maximum tokens in responses
831    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
832        self.max_tokens = Some(max_tokens);
833        self
834    }
835
836    /// Sets the default nucleus sampling threshold.
837    ///
838    /// # Arguments
839    ///
840    /// * `top_p` - Value between 0.0 and 1.0
841    pub fn with_top_p(mut self, top_p: f32) -> Self {
842        self.top_p = Some(top_p);
843        self
844    }
845
846    /// Sets the default frequency penalty.
847    ///
848    /// # Arguments
849    ///
850    /// * `frequency_penalty` - Value between -2.0 and 2.0
851    pub fn with_frequency_penalty(mut self, frequency_penalty: f32) -> Self {
852        self.frequency_penalty = Some(frequency_penalty);
853        self
854    }
855
856    /// Sets the default presence penalty.
857    ///
858    /// # Arguments
859    ///
860    /// * `presence_penalty` - Value between -2.0 and 2.0
861    pub fn with_presence_penalty(mut self, presence_penalty: f32) -> Self {
862        self.presence_penalty = Some(presence_penalty);
863        self
864    }
865
866    /// Sets the default stop sequences.
867    ///
868    /// # Arguments
869    ///
870    /// * `stop_sequences` - An iterable of stop sequences
871    pub fn with_stop_sequences(
872        mut self,
873        stop_sequences: impl IntoIterator<Item = impl Into<String>>,
874    ) -> Self {
875        self.stop_sequences = Some(stop_sequences.into_iter().map(|s| s.into()).collect());
876        self
877    }
878
879    /// Sets the default metadata.
880    ///
881    /// # Arguments
882    ///
883    /// * `metadata` - Key-value pairs of metadata
884    pub fn with_metadata(mut self, metadata: HashMap<String, serde_json::Value>) -> Self {
885        self.metadata = metadata;
886        self
887    }
888
889    /// Sets the retry configuration.
890    ///
891    /// # Arguments
892    ///
893    /// * `retry_config` - The retry configuration
894    pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self {
895        self.retry_config = retry_config;
896        self
897    }
898}
899
900impl From<(Config, Vec<Message>)> for ChatRequest {
901    fn from((config, messages): (Config, Vec<Message>)) -> Self {
902        let mut request = ChatRequest::new(messages).with_model(&config.model);
903
904        if let Some(temperature) = config.temperature {
905            request = request.with_temperature(temperature);
906        }
907
908        if let Some(max_tokens) = config.max_tokens {
909            request = request.with_max_tokens(max_tokens);
910        }
911
912        if let Some(top_p) = config.top_p {
913            request.top_p = Some(top_p);
914        }
915
916        if let Some(frequency_penalty) = config.frequency_penalty {
917            request.frequency_penalty = Some(frequency_penalty);
918        }
919
920        if let Some(presence_penalty) = config.presence_penalty {
921            request.presence_penalty = Some(presence_penalty);
922        }
923
924        if let Some(stop_sequences) = config.stop_sequences {
925            request.stop = Some(stop_sequences);
926        }
927
928        request.metadata = config.metadata;
929
930        request
931    }
932}
933
934impl Config {
935    /// Converts this configuration and messages into a chat request.
936    ///
937    /// This is a convenience method that creates a `ChatRequest` with all
938    /// default parameters from this configuration.
939    ///
940    /// # Arguments
941    ///
942    /// * `messages` - The conversation messages
943    ///
944    /// # Examples
945    ///
946    /// ```
947    /// use neuromance_common::{Config, Message, MessageRole};
948    /// use uuid::Uuid;
949    ///
950    /// let config = Config::new("openai", "gpt-4").with_temperature(0.7);
951    /// let msg = Message::new(Uuid::new_v4(), MessageRole::User, "Hello!");
952    /// let request = config.into_chat_request(vec![msg]);
953    /// ```
954    pub fn into_chat_request(self, messages: Vec<Message>) -> ChatRequest {
955        (self, messages).into()
956    }
957
958    /// Validates the configuration parameters.
959    ///
960    /// Checks that all numeric parameters are within their valid ranges.
961    ///
962    /// # Errors
963    ///
964    /// Returns an error if any parameter is out of range:
965    /// - `temperature` must be between 0.0 and 2.0
966    /// - `top_p` must be between 0.0 and 1.0
967    /// - `frequency_penalty` must be between -2.0 and 2.0
968    /// - `presence_penalty` must be between -2.0 and 2.0
969    pub fn validate(&self) -> anyhow::Result<()> {
970        if let Some(temp) = self.temperature
971            && !(0.0..=2.0).contains(&temp)
972        {
973            anyhow::bail!("Temperature must be between 0.0 and 2.0, got {}", temp);
974        }
975
976        if let Some(top_p) = self.top_p
977            && !(0.0..=1.0).contains(&top_p)
978        {
979            anyhow::bail!("top_p must be between 0.0 and 1.0, got {}", top_p);
980        }
981
982        if let Some(freq_penalty) = self.frequency_penalty
983            && !(-2.0..=2.0).contains(&freq_penalty)
984        {
985            anyhow::bail!(
986                "frequency_penalty must be between -2.0 and 2.0, got {}",
987                freq_penalty
988            );
989        }
990
991        if let Some(pres_penalty) = self.presence_penalty
992            && !(-2.0..=2.0).contains(&pres_penalty)
993        {
994            anyhow::bail!(
995                "presence_penalty must be between -2.0 and 2.0, got {}",
996                pres_penalty
997            );
998        }
999
1000        Ok(())
1001    }
1002}
1003
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #[test]
        fn temperature_validation(temp in -10.0f32..10.0f32) {
            let config = Config::new("openai", "gpt-4").with_temperature(temp);
            let is_valid = (0.0..=2.0).contains(&temp);
            assert_eq!(config.validate().is_ok(), is_valid);
        }

        #[test]
        fn top_p_validation(top_p in -5.0f32..5.0f32) {
            let config = Config::new("openai", "gpt-4").with_top_p(top_p);
            let is_valid = (0.0..=1.0).contains(&top_p);
            assert_eq!(config.validate().is_ok(), is_valid);
        }

        #[test]
        fn frequency_penalty_validation(penalty in -10.0f32..10.0f32) {
            let config = Config::new("openai", "gpt-4").with_frequency_penalty(penalty);
            let is_valid = (-2.0..=2.0).contains(&penalty);
            assert_eq!(config.validate().is_ok(), is_valid);
        }

        #[test]
        fn presence_penalty_validation(penalty in -10.0f32..10.0f32) {
            let config = Config::new("openai", "gpt-4").with_presence_penalty(penalty);
            let is_valid = (-2.0..=2.0).contains(&penalty);
            assert_eq!(config.validate().is_ok(), is_valid);
        }

        #[test]
        fn max_tokens_validation(tokens in 0u32..1000000u32) {
            let config = Config::new("openai", "gpt-4").with_max_tokens(tokens);
            // `Config::validate` places no range restriction on `max_tokens`,
            // so every value (including 0) must pass.
            assert!(config.validate().is_ok());
        }

        #[test]
        fn config_builder_with_string_slice(
            provider in ".*",
            model in ".*",
            base_url in ".*",
        ) {
            let config = Config::new(provider.as_str(), model.as_str())
                .with_base_url(base_url.as_str());

            // Should compile and work with &str
            assert_eq!(config.provider, provider);
            assert_eq!(config.model, model);
            assert_eq!(config.base_url, Some(base_url));
        }

        #[test]
        fn config_builder_with_owned_string(
            provider in ".*",
            model in ".*",
        ) {
            let config = Config::new(provider.clone(), model.clone());

            // Should compile and work with String
            assert_eq!(config.provider, provider);
            assert_eq!(config.model, model);
        }

        #[test]
        fn stop_sequences_accepts_various_types(
            sequences in prop::collection::vec(".*", 0..10),
        ) {
            let expected = Some(sequences.clone());

            // Owned `Vec<String>`
            let from_owned = Config::new("openai", "gpt-4")
                .with_stop_sequences(sequences.clone());
            assert_eq!(from_owned.stop_sequences, expected);

            // Borrowed `Vec<&str>`
            let str_refs: Vec<&str> = sequences.iter().map(String::as_str).collect();
            let from_strs = Config::new("openai", "gpt-4")
                .with_stop_sequences(str_refs);
            assert_eq!(from_strs.stop_sequences, expected);

            // Iterator over `&String`, without collecting first
            let from_iter = Config::new("openai", "gpt-4")
                .with_stop_sequences(sequences.iter());
            assert_eq!(from_iter.stop_sequences, expected);
        }

        #[test]
        fn builder_chain_preserves_all_values(
            provider in ".*",
            model in ".*",
            temp in 0.0f32..2.0f32,
            max_tokens in 0u32..100000u32,
        ) {
            let config = Config::new(provider.as_str(), model.as_str())
                .with_temperature(temp)
                .with_max_tokens(max_tokens);

            assert_eq!(config.provider, provider);
            assert_eq!(config.model, model);
            assert_eq!(config.temperature, Some(temp));
            assert_eq!(config.max_tokens, Some(max_tokens));
            assert!(config.validate().is_ok());
        }

        // ChatRequest property tests
        #[test]
        fn chat_request_temperature_validation(
            temp in -10.0f32..10.0f32,
            msg_count in 1usize..10,
        ) {
            use crate::chat::{Message, MessageRole};
            use uuid::Uuid;

            let messages: Vec<Message> = (0..msg_count)
                .map(|i| Message::new(Uuid::new_v4(), MessageRole::User, format!("message {}", i)))
                .collect();

            let request = ChatRequest::new(messages).with_temperature(temp);
            let is_valid = (0.0..=2.0).contains(&temp);
            assert_eq!(request.validate().is_ok(), is_valid);
        }

        #[test]
        fn chat_request_with_string_types(
            model in ".*",
        ) {
            use crate::chat::{Message, MessageRole};
            use uuid::Uuid;

            let msg = Message::new(Uuid::new_v4(), MessageRole::User, "test");

            // Test with &str
            let request1 = ChatRequest::new(vec![msg.clone()])
                .with_model(model.as_str());
            assert_eq!(request1.model, Some(model.clone()));

            // Test with String
            let request2 = ChatRequest::new(vec![msg])
                .with_model(model.clone());
            assert_eq!(request2.model, Some(model));
        }

        #[test]
        fn chat_request_stop_sequences_ergonomics(
            sequences in prop::collection::vec(".*", 1..5),
        ) {
            use crate::chat::{Message, MessageRole};
            use uuid::Uuid;

            let msg = Message::new(Uuid::new_v4(), MessageRole::User, "test");

            // Owned `Vec<String>`
            let request1 = ChatRequest::new(vec![msg.clone()])
                .with_stop_sequences(sequences.clone());
            assert_eq!(request1.stop, Some(sequences.clone()));

            // Borrowed `Vec<&str>`
            let str_refs: Vec<&str> = sequences.iter().map(String::as_str).collect();
            let request2 = ChatRequest::new(vec![msg])
                .with_stop_sequences(str_refs);
            assert_eq!(request2.stop, Some(sequences));
        }

        #[test]
        fn chat_request_builder_chain(
            model in ".*",
            temp in 0.0f32..2.0f32,
            max_tokens in 0u32..100000u32,
            top_p in 0.0f32..1.0f32,
        ) {
            use crate::chat::{Message, MessageRole};
            use uuid::Uuid;

            let msg = Message::new(Uuid::new_v4(), MessageRole::User, "test");
            let request = ChatRequest::new(vec![msg])
                .with_model(model.as_str())
                .with_temperature(temp)
                .with_max_tokens(max_tokens)
                .with_top_p(top_p);

            assert_eq!(request.model, Some(model));
            assert_eq!(request.temperature, Some(temp));
            assert_eq!(request.max_tokens, Some(max_tokens));
            assert_eq!(request.top_p, Some(top_p));
            assert!(request.validate().is_ok());
        }
    }

    #[test]
    fn stop_sequences_accepts_fixed_size_array() {
        let config = Config::new("openai", "gpt-4").with_stop_sequences(["\n\n", "END"]);
        assert_eq!(
            config.stop_sequences,
            Some(vec!["\n\n".to_string(), "END".to_string()])
        );
    }

    // Float range strategies essentially never produce the exact endpoints,
    // so the inclusive bounds are pinned explicitly here.
    #[test]
    fn validate_accepts_inclusive_bounds() {
        let lower = Config::new("openai", "gpt-4")
            .with_temperature(0.0)
            .with_top_p(0.0)
            .with_frequency_penalty(-2.0)
            .with_presence_penalty(-2.0);
        assert!(lower.validate().is_ok());

        let upper = Config::new("openai", "gpt-4")
            .with_temperature(2.0)
            .with_top_p(1.0)
            .with_frequency_penalty(2.0)
            .with_presence_penalty(2.0);
        assert!(upper.validate().is_ok());
    }

    #[test]
    fn validate_rejects_non_finite_values() {
        for bad in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
            let base = || Config::new("openai", "gpt-4");
            assert!(base().with_temperature(bad).validate().is_err());
            assert!(base().with_top_p(bad).validate().is_err());
            assert!(base().with_frequency_penalty(bad).validate().is_err());
            assert!(base().with_presence_penalty(bad).validate().is_err());
        }
    }

    #[test]
    fn into_chat_request_copies_config_defaults() {
        use crate::chat::{Message, MessageRole};
        use uuid::Uuid;

        let mut metadata = std::collections::HashMap::new();
        metadata.insert("trace_id".to_string(), serde_json::json!("abc"));

        let config = Config::new("openai", "gpt-4")
            .with_temperature(0.5)
            .with_max_tokens(256)
            .with_top_p(0.9)
            .with_frequency_penalty(0.25)
            .with_presence_penalty(-0.5)
            .with_stop_sequences(["END"])
            .with_metadata(metadata);

        let msg = Message::new(Uuid::new_v4(), MessageRole::User, "hi");
        let request = config.into_chat_request(vec![msg]);

        assert_eq!(request.model, Some("gpt-4".to_string()));
        assert_eq!(request.temperature, Some(0.5));
        assert_eq!(request.max_tokens, Some(256));
        assert_eq!(request.top_p, Some(0.9));
        assert_eq!(request.frequency_penalty, Some(0.25));
        assert_eq!(request.presence_penalty, Some(-0.5));
        assert_eq!(request.stop, Some(vec!["END".to_string()]));
        assert_eq!(
            request.metadata.get("trace_id"),
            Some(&serde_json::json!("abc"))
        );
    }

    #[test]
    fn chat_request_validates_empty_messages() {
        let request = ChatRequest::new(vec![]);
        assert!(request.validate().is_err());
        assert!(request.validate_has_messages().is_err());
    }

    #[test]
    fn chat_request_has_tools() {
        use crate::chat::{Message, MessageRole};
        use crate::tools::{Function, Tool};
        use uuid::Uuid;

        let msg = Message::new(Uuid::new_v4(), MessageRole::User, "test");

        // Test with no tools (None)
        let request_no_tools = ChatRequest::new(vec![msg.clone()]);
        assert!(!request_no_tools.has_tools());

        // Test with empty tools vector
        let request_empty_tools = ChatRequest::new(vec![msg.clone()]).with_tools(vec![]);
        assert!(!request_empty_tools.has_tools());

        // Test with tools present
        let function = Function {
            name: "test_function".to_string(),
            description: "A test function".to_string(),
            parameters: serde_json::json!({}),
        };
        let tool = Tool::builder().function(function).build();
        let request_with_tools = ChatRequest::new(vec![msg]).with_tools(vec![tool]);
        assert!(request_with_tools.has_tools());
    }
}