1use serde::Serialize;
4use std::time::Duration;
5
/// Selects how a session's conversation pipeline is run.
///
/// [`PipelineMode::Live`] is the default variant.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum PipelineMode {
    /// Default mode; its string form is `"live"`.
    #[default]
    Live,
    /// Alternative mode; its string form is `"composed"`.
    Composed,
}

impl PipelineMode {
    /// Returns the lowercase string name of this mode (`"live"` or `"composed"`).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Live => "live",
            Self::Composed => "composed",
        }
    }
}

impl std::fmt::Display for PipelineMode {
    /// Writes exactly the text returned by [`PipelineMode::as_str`]
    /// (width/alignment flags are ignored).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
34
/// A deployment region; each region maps to one fixed `wss://` endpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Region {
    /// `ap-northeast-2`, served by `wss://talk.drawdream.co.kr`.
    ApNortheast2,
    /// `us-west-2`, served by `wss://talk.drawdream.ca`.
    UsWest2,
}

impl Region {
    /// Single source of truth for this region: `(identifier, endpoint URL)`.
    fn descriptor(self) -> (&'static str, &'static str) {
        match self {
            Region::ApNortheast2 => ("ap-northeast-2", "wss://talk.drawdream.co.kr"),
            Region::UsWest2 => ("us-west-2", "wss://talk.drawdream.ca"),
        }
    }

    /// Returns the WebSocket endpoint URL associated with this region.
    pub fn endpoint(&self) -> &'static str {
        let (_, url) = self.descriptor();
        url
    }

    /// Returns the region identifier, e.g. `"ap-northeast-2"`.
    pub fn as_str(&self) -> &'static str {
        let (id, _) = self.descriptor();
        id
    }
}

impl std::fmt::Display for Region {
    /// Writes the region identifier returned by [`Region::as_str`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
67
/// Resolved client configuration, normally produced by [`ConfigBuilder::build`].
///
/// `Debug` is implemented by hand rather than derived so that the API key is
/// printed as `"<redacted>"`. A derived impl would leak the credential into any
/// `{:?}` log line.
#[derive(Clone)]
pub struct Config {
    /// Endpoint URL; the builder fills this from [`Region::endpoint`].
    pub endpoint: String,
    /// API key credential. Never shown by the `Debug` impl.
    pub api_key: String,
    /// Optional user identifier; `None` unless set on the builder.
    pub user_id: Option<String>,
    /// Connection timeout (builder default: 30 seconds).
    pub connection_timeout: Duration,
    /// Automatic-reconnect flag (builder default: `true`).
    pub auto_reconnect: bool,
    /// Maximum number of reconnect attempts (builder default: 5).
    pub max_reconnect_attempts: u32,
    /// Reconnect delay (builder default: 1 second).
    pub reconnect_delay: Duration,
    /// Debug flag (builder default: `false`).
    pub debug: bool,
}

impl std::fmt::Debug for Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Config")
            .field("endpoint", &self.endpoint)
            // Secret: never print the actual key.
            .field("api_key", &"<redacted>")
            .field("user_id", &self.user_id)
            .field("connection_timeout", &self.connection_timeout)
            .field("auto_reconnect", &self.auto_reconnect)
            .field("max_reconnect_attempts", &self.max_reconnect_attempts)
            .field("reconnect_delay", &self.reconnect_delay)
            .field("debug", &self.debug)
            .finish()
    }
}
88
89impl Config {
90 pub fn builder() -> ConfigBuilder {
92 ConfigBuilder::default()
93 }
94}
95
96#[derive(Debug, Default)]
98pub struct ConfigBuilder {
99 region: Option<Region>,
100 api_key: Option<String>,
101 user_id: Option<String>,
102 connection_timeout: Option<Duration>,
103 auto_reconnect: Option<bool>,
104 max_reconnect_attempts: Option<u32>,
105 reconnect_delay: Option<Duration>,
106 debug: Option<bool>,
107}
108
109impl ConfigBuilder {
110 pub fn region(mut self, region: Region) -> Self {
125 self.region = Some(region);
126 self
127 }
128
129 pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
131 self.api_key = Some(api_key.into());
132 self
133 }
134
135 pub fn user_id(mut self, user_id: impl Into<String>) -> Self {
149 self.user_id = Some(user_id.into());
150 self
151 }
152
153 pub fn connection_timeout(mut self, timeout: Duration) -> Self {
155 self.connection_timeout = Some(timeout);
156 self
157 }
158
159 pub fn auto_reconnect(mut self, enabled: bool) -> Self {
161 self.auto_reconnect = Some(enabled);
162 self
163 }
164
165 pub fn max_reconnect_attempts(mut self, attempts: u32) -> Self {
167 self.max_reconnect_attempts = Some(attempts);
168 self
169 }
170
171 pub fn reconnect_delay(mut self, delay: Duration) -> Self {
173 self.reconnect_delay = Some(delay);
174 self
175 }
176
177 pub fn debug(mut self, enabled: bool) -> Self {
179 self.debug = Some(enabled);
180 self
181 }
182
183 pub fn build(self) -> Result<Config, ConfigError> {
185 let region = self.region.ok_or(ConfigError::MissingRegion)?;
186 let api_key = self.api_key.ok_or(ConfigError::MissingApiKey)?;
187
188 Ok(Config {
189 endpoint: region.endpoint().to_string(),
190 api_key,
191 user_id: self.user_id,
192 connection_timeout: self.connection_timeout.unwrap_or(Duration::from_secs(30)),
193 auto_reconnect: self.auto_reconnect.unwrap_or(true),
194 max_reconnect_attempts: self.max_reconnect_attempts.unwrap_or(5),
195 reconnect_delay: self.reconnect_delay.unwrap_or(Duration::from_secs(1)),
196 debug: self.debug.unwrap_or(false),
197 })
198 }
199}
200
201#[derive(Debug, thiserror::Error)]
203pub enum ConfigError {
204 #[error("region is required")]
205 MissingRegion,
206 #[error("api_key is required")]
207 MissingApiKey,
208}
209
210#[derive(Debug, Clone, Serialize)]
234pub struct Tool {
235 pub name: String,
237 pub description: String,
239 #[serde(skip_serializing_if = "Option::is_none")]
241 pub parameters: Option<FunctionParameters>,
242}
243
244#[derive(Debug, Clone, Serialize)]
246pub struct FunctionParameters {
247 #[serde(rename = "type")]
249 pub r#type: String,
250 pub properties: serde_json::Value,
252 #[serde(skip_serializing_if = "Vec::is_empty")]
254 pub required: Vec<String>,
255}
256
257impl Tool {
258 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
260 Self {
261 name: name.into(),
262 description: description.into(),
263 parameters: None,
264 }
265 }
266
267 pub fn with_parameters(
269 name: impl Into<String>,
270 description: impl Into<String>,
271 properties: serde_json::Value,
272 required: Vec<String>,
273 ) -> Self {
274 Self {
275 name: name.into(),
276 description: description.into(),
277 parameters: Some(FunctionParameters {
278 r#type: "OBJECT".to_string(),
279 properties,
280 required,
281 }),
282 }
283 }
284}
285
286#[derive(Debug, Clone)]
292pub struct SessionConfig {
293 pub pre_prompt: Option<String>,
295 pub language: Option<String>,
297 pub pipeline_mode: PipelineMode,
299 pub ai_speaks_first: bool,
304 pub allow_harm_category: bool,
307 pub tools: Option<Vec<Tool>>,
310}
311
312impl Default for SessionConfig {
313 fn default() -> Self {
314 Self {
315 pre_prompt: None,
316 language: None,
317 pipeline_mode: PipelineMode::Live,
318 ai_speaks_first: false,
319 allow_harm_category: false,
320 tools: None,
321 }
322 }
323}
324
325impl SessionConfig {
326 pub fn new(pre_prompt: impl Into<String>) -> Self {
328 Self {
329 pre_prompt: Some(pre_prompt.into()),
330 language: None,
331 pipeline_mode: PipelineMode::Live,
332 ai_speaks_first: false,
333 allow_harm_category: false,
334 tools: None,
335 }
336 }
337
338 pub fn empty() -> Self {
340 Self::default()
341 }
342
343 pub fn with_language(mut self, language: impl Into<String>) -> Self {
345 self.language = Some(language.into());
346 self
347 }
348
349 pub fn with_pre_prompt(mut self, pre_prompt: impl Into<String>) -> Self {
351 self.pre_prompt = Some(pre_prompt.into());
352 self
353 }
354
355 pub fn with_pipeline_mode(mut self, mode: PipelineMode) -> Self {
360 self.pipeline_mode = mode;
361 self
362 }
363
364 pub fn with_ai_speaks_first(mut self, enabled: bool) -> Self {
377 self.ai_speaks_first = enabled;
378 self
379 }
380
381 pub fn with_allow_harm_category(mut self, allow: bool) -> Self {
392 self.allow_harm_category = allow;
393 self
394 }
395
396 pub fn with_tools(mut self, tools: Vec<Tool>) -> Self {
410 self.tools = Some(tools);
411 self
412 }
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_builder_with_region() {
        let config = Config::builder()
            .region(Region::ApNortheast2)
            .api_key("test-key")
            .debug(true)
            .build()
            .unwrap();

        assert_eq!(config.endpoint, "wss://talk.drawdream.co.kr");
        assert_eq!(config.api_key, "test-key");
        assert!(config.debug);
    }

    #[test]
    fn test_config_builder_defaults() {
        let config = Config::builder()
            .region(Region::UsWest2)
            .api_key("test-key")
            .build()
            .unwrap();

        assert_eq!(config.endpoint, "wss://talk.drawdream.ca");
        assert_eq!(config.user_id, None);
        assert_eq!(config.connection_timeout, Duration::from_secs(30));
        assert!(config.auto_reconnect);
        assert_eq!(config.max_reconnect_attempts, 5);
        assert_eq!(config.reconnect_delay, Duration::from_secs(1));
        assert!(!config.debug);
    }

    #[test]
    fn test_config_builder_overrides() {
        let config = Config::builder()
            .region(Region::ApNortheast2)
            .api_key("k")
            .user_id("user-1")
            .connection_timeout(Duration::from_secs(5))
            .auto_reconnect(false)
            .max_reconnect_attempts(0)
            .reconnect_delay(Duration::from_millis(250))
            .build()
            .unwrap();

        assert_eq!(config.user_id.as_deref(), Some("user-1"));
        assert_eq!(config.connection_timeout, Duration::from_secs(5));
        assert!(!config.auto_reconnect);
        assert_eq!(config.max_reconnect_attempts, 0);
        assert_eq!(config.reconnect_delay, Duration::from_millis(250));
    }

    #[test]
    fn test_config_builder_missing_region() {
        let result = Config::builder().api_key("test-key").build();
        assert!(matches!(result, Err(ConfigError::MissingRegion)));
    }

    #[test]
    fn test_config_builder_missing_api_key() {
        let result = Config::builder().region(Region::ApNortheast2).build();
        assert!(matches!(result, Err(ConfigError::MissingApiKey)));
    }

    #[test]
    fn test_config_builder_missing_both_reports_region_first() {
        let result = Config::builder().build();
        assert!(matches!(result, Err(ConfigError::MissingRegion)));
    }

    #[test]
    fn test_region_endpoint() {
        assert_eq!(Region::ApNortheast2.endpoint(), "wss://talk.drawdream.co.kr");
        assert_eq!(Region::ApNortheast2.as_str(), "ap-northeast-2");
        assert_eq!(Region::UsWest2.endpoint(), "wss://talk.drawdream.ca");
        assert_eq!(Region::UsWest2.as_str(), "us-west-2");
    }

    #[test]
    fn test_display_matches_as_str() {
        assert_eq!(Region::UsWest2.to_string(), Region::UsWest2.as_str());
        assert_eq!(PipelineMode::Live.to_string(), "live");
        assert_eq!(PipelineMode::Composed.to_string(), "composed");
    }

    #[test]
    fn test_session_config() {
        let config = SessionConfig::new("You are a helpful assistant");
        assert_eq!(config.pre_prompt, Some("You are a helpful assistant".to_string()));
    }

    #[test]
    fn test_session_config_new_uses_defaults() {
        let config = SessionConfig::new("prompt");
        assert_eq!(config.language, None);
        assert_eq!(config.pipeline_mode, PipelineMode::Live);
        assert!(!config.ai_speaks_first);
        assert!(!config.allow_harm_category);
        assert!(config.tools.is_none());
    }

    #[test]
    fn test_session_config_builder_chain() {
        let config = SessionConfig::empty()
            .with_pre_prompt("p")
            .with_language("ko")
            .with_ai_speaks_first(true)
            .with_allow_harm_category(true)
            .with_tools(vec![Tool::new("ping", "Check liveness")]);

        assert_eq!(config.pre_prompt.as_deref(), Some("p"));
        assert_eq!(config.language.as_deref(), Some("ko"));
        assert!(config.ai_speaks_first);
        assert!(config.allow_harm_category);
        assert_eq!(config.tools.as_ref().map(Vec::len), Some(1));
    }

    #[test]
    fn test_session_config_empty() {
        let config = SessionConfig::empty();
        assert_eq!(config.pre_prompt, None);
    }

    #[test]
    fn test_pipeline_mode_default() {
        let config = SessionConfig::empty();
        assert_eq!(config.pipeline_mode, PipelineMode::Live);
    }

    #[test]
    fn test_pipeline_mode_composed() {
        let config = SessionConfig::empty().with_pipeline_mode(PipelineMode::Composed);
        assert_eq!(config.pipeline_mode, PipelineMode::Composed);
        assert_eq!(config.pipeline_mode.as_str(), "composed");
    }

    #[test]
    fn test_tool_serialization() {
        let plain = serde_json::to_value(Tool::new("ping", "Check liveness")).unwrap();
        assert_eq!(plain["name"], "ping");
        assert!(plain.get("parameters").is_none());

        let tool = Tool::with_parameters(
            "get_weather",
            "Look up the weather",
            serde_json::json!({ "city": { "type": "STRING" } }),
            vec!["city".to_string()],
        );
        let value = serde_json::to_value(&tool).unwrap();
        assert_eq!(value["parameters"]["type"], "OBJECT");
        assert_eq!(value["parameters"]["required"][0], "city");

        let no_required = serde_json::to_value(Tool::with_parameters(
            "f",
            "d",
            serde_json::json!({}),
            Vec::new(),
        ))
        .unwrap();
        assert!(no_required["parameters"].get("required").is_none());
    }
}