1use std::fmt;
2use std::str::FromStr;
3
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
22#[serde(rename_all = "lowercase")]
23pub enum Transport {
24 Stdio,
26 Http,
28 Sse,
30}
31
32impl fmt::Display for Transport {
33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 match self {
35 Self::Stdio => write!(f, "stdio"),
36 Self::Http => write!(f, "http"),
37 Self::Sse => write!(f, "sse"),
38 }
39 }
40}
41
42#[derive(Debug, Clone, thiserror::Error)]
44#[error("unknown transport: {0}")]
45pub struct TransportParseError(pub String);
46
47impl FromStr for Transport {
48 type Err = TransportParseError;
49
50 fn from_str(s: &str) -> Result<Self, Self::Err> {
51 match s {
52 "stdio" => Ok(Self::Stdio),
53 "http" => Ok(Self::Http),
54 "sse" => Ok(Self::Sse),
55 other => Err(TransportParseError(other.to_string())),
56 }
57 }
58}
59
60impl TryFrom<&str> for Transport {
61 type Error = TransportParseError;
62
63 fn try_from(s: &str) -> Result<Self, Self::Error> {
64 s.parse()
65 }
66}
67
#[cfg(test)]
mod transport_tests {
    use super::*;

    /// Every variant, so the round-trip tests stay exhaustive.
    const ALL: [Transport; 3] = [Transport::Stdio, Transport::Http, Transport::Sse];

    #[test]
    fn from_str_accepts_known_variants() {
        assert_eq!("stdio".parse::<Transport>().unwrap(), Transport::Stdio);
        assert_eq!("http".parse::<Transport>().unwrap(), Transport::Http);
        assert_eq!("sse".parse::<Transport>().unwrap(), Transport::Sse);
    }

    #[test]
    fn from_str_returns_error_for_unknown() {
        let err = "websocket".parse::<Transport>().unwrap_err();
        assert!(err.to_string().contains("websocket"));
    }

    #[test]
    fn from_str_is_case_sensitive_and_rejects_empty() {
        assert!("HTTP".parse::<Transport>().is_err());
        assert!(" stdio".parse::<Transport>().is_err());
        assert!("".parse::<Transport>().is_err());
    }

    #[test]
    fn display_round_trips_through_from_str() {
        for transport in ALL {
            assert_eq!(transport.to_string().parse::<Transport>().unwrap(), transport);
        }
    }

    #[test]
    fn try_from_str_accepts_known_variants() {
        for transport in ALL {
            let parsed = Transport::try_from(transport.to_string().as_str()).unwrap();
            assert_eq!(parsed, transport);
        }
    }

    #[test]
    fn try_from_str_returns_error_for_unknown() {
        assert!(Transport::try_from("bogus").is_err());
    }
}
90
/// Output format option; [`OutputFormat::as_arg`] gives its argument spelling.
///
/// Defaults to [`OutputFormat::Text`]. Derives `PartialEq`/`Eq`/`Hash` (like
/// `Transport` and `PermissionMode`) so values can be compared and used as set/map keys.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum OutputFormat {
    /// `"text"`; the default.
    #[default]
    Text,
    /// `"json"`.
    Json,
    /// `"stream-json"`.
    StreamJson,
}

impl OutputFormat {
    /// Returns the argument spelling: `"text"`, `"json"` or `"stream-json"`.
    #[must_use]
    pub(crate) const fn as_arg(&self) -> &'static str {
        match self {
            Self::Text => "text",
            Self::Json => "json",
            Self::StreamJson => "stream-json",
        }
    }
}
112
113#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
115#[serde(rename_all = "camelCase")]
116pub enum PermissionMode {
117 #[default]
119 Default,
120 AcceptEdits,
122 BypassPermissions,
124 DontAsk,
126 Plan,
128 Auto,
130}
131
132impl PermissionMode {
133 pub(crate) fn as_arg(&self) -> &'static str {
134 match self {
135 Self::Default => "default",
136 Self::AcceptEdits => "acceptEdits",
137 Self::BypassPermissions => "bypassPermissions",
138 Self::DontAsk => "dontAsk",
139 Self::Plan => "plan",
140 Self::Auto => "auto",
141 }
142 }
143}
144
/// Input format option; [`InputFormat::as_arg`] gives its argument spelling.
///
/// Defaults to [`InputFormat::Text`]. Derives `PartialEq`/`Eq`/`Hash` (like
/// `Transport` and `PermissionMode`) so values can be compared and used as set/map keys.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum InputFormat {
    /// `"text"`; the default.
    #[default]
    Text,
    /// `"stream-json"`.
    StreamJson,
}

impl InputFormat {
    /// Returns the argument spelling: `"text"` or `"stream-json"`.
    #[must_use]
    pub(crate) const fn as_arg(&self) -> &'static str {
        match self {
            Self::Text => "text",
            Self::StreamJson => "stream-json",
        }
    }
}
163
164#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
166#[serde(rename_all = "lowercase")]
167pub enum Effort {
168 Low,
170 Medium,
172 High,
174 Max,
176}
177
178impl Effort {
179 pub(crate) fn as_arg(&self) -> &'static str {
180 match self {
181 Self::Low => "low",
182 Self::Medium => "medium",
183 Self::High => "high",
184 Self::Max => "max",
185 }
186 }
187}
188
/// Scope option; [`Scope::as_arg`] gives its argument spelling.
///
/// Defaults to [`Scope::Local`]. Derives `PartialEq`/`Eq`/`Hash` (like `Transport`
/// and `PermissionMode`) so values can be compared and used as set/map keys.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum Scope {
    /// `"local"`; the default.
    #[default]
    Local,
    /// `"user"`.
    User,
    /// `"project"`.
    Project,
}

impl Scope {
    /// Returns the argument spelling: `"local"`, `"user"` or `"project"`.
    #[must_use]
    pub(crate) const fn as_arg(&self) -> &'static str {
        match self {
            Self::Local => "local",
            Self::User => "user",
            Self::Project => "project",
        }
    }
}
210
#[cfg(feature = "json")]
/// Authentication status payload with camelCase JSON keys.
///
/// The keys are `loggedIn`, `authMethod`, `apiProvider`, `email`, `orgId`, `orgName`
/// and `subscriptionType`.
///
/// Every named field is `#[serde(default)]`, so a missing key becomes `false`/`None`.
/// Consequently `AuthStatus::default()` equals the value obtained by deserializing `{}`.
/// Any other top-level keys are captured in [`AuthStatus::extra`] and emitted again on
/// serialization.
#[derive(Debug, Clone, Default, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct AuthStatus {
    /// JSON key `loggedIn`; `false` when absent.
    #[serde(default)]
    pub logged_in: bool,
    /// JSON key `authMethod`; `None` when absent.
    #[serde(default)]
    pub auth_method: Option<String>,
    /// JSON key `apiProvider`; `None` when absent.
    #[serde(default)]
    pub api_provider: Option<String>,
    /// JSON key `email`; `None` when absent.
    #[serde(default)]
    pub email: Option<String>,
    /// JSON key `orgId`; `None` when absent.
    #[serde(default)]
    pub org_id: Option<String>,
    /// JSON key `orgName`; `None` when absent.
    #[serde(default)]
    pub org_name: Option<String>,
    /// JSON key `subscriptionType`; `None` when absent.
    #[serde(default)]
    pub subscription_type: Option<String>,
    /// All remaining top-level keys, preserved verbatim.
    #[serde(flatten)]
    pub extra: std::collections::HashMap<String, serde_json::Value>,
}
256
#[cfg(feature = "json")]
/// A message payload with `role` and `content` keys.
///
/// Both named fields are `#[serde(default)]`: a missing `role` becomes `""` and a
/// missing `content` becomes `serde_json::Value::Null`. Consequently
/// `QueryMessage::default()` equals the value obtained by deserializing `{}`. Any other
/// top-level keys are captured in [`QueryMessage::extra`].
#[derive(Debug, Clone, Default, PartialEq, Deserialize, Serialize)]
pub struct QueryMessage {
    /// JSON key `role`; empty when absent.
    #[serde(default)]
    pub role: String,
    /// JSON key `content`, kept as raw JSON; `Null` when absent.
    #[serde(default)]
    pub content: serde_json::Value,
    /// All remaining top-level keys, preserved verbatim.
    #[serde(flatten)]
    pub extra: std::collections::HashMap<String, serde_json::Value>,
}
271
#[cfg(feature = "json")]
/// Result payload of a query; the JSON keys are the Rust field names, except for `cost_usd`.
///
/// `cost_usd` is written as `total_cost_usd` and read from either `total_cost_usd` or
/// the legacy alias `cost_usd`. Every named field is `#[serde(default)]`, so
/// `QueryResult::default()` equals the value obtained by deserializing `{}`. Keys not
/// named here (for example `type` and `subtype` on a streamed result event) are captured
/// in [`QueryResult::extra`].
///
/// NOTE(review): serde treats `total_cost_usd` and its alias `cost_usd` as one field.
/// A payload carrying *both* keys is therefore expected to fail with a duplicate-field
/// error. Confirm whether any producer emits both.
#[derive(Debug, Clone, Default, PartialEq, Deserialize, Serialize)]
pub struct QueryResult {
    /// JSON key `result`; empty when absent.
    #[serde(default)]
    pub result: String,
    /// JSON key `session_id`; empty when absent.
    #[serde(default)]
    pub session_id: String,
    /// JSON key `total_cost_usd` (alias `cost_usd` on input); `None` when absent.
    #[serde(default, rename = "total_cost_usd", alias = "cost_usd")]
    pub cost_usd: Option<f64>,
    /// JSON key `duration_ms`; `None` when absent.
    #[serde(default)]
    pub duration_ms: Option<u64>,
    /// JSON key `num_turns`; `None` when absent.
    #[serde(default)]
    pub num_turns: Option<u32>,
    /// JSON key `is_error`; `false` when absent.
    #[serde(default)]
    pub is_error: bool,
    /// All remaining top-level keys, preserved verbatim.
    #[serde(flatten)]
    pub extra: std::collections::HashMap<String, serde_json::Value>,
}
298
#[cfg(all(test, feature = "json"))]
mod tests {
    use super::*;

    /// Deserializes a `QueryResult`, panicking with the offending payload on failure.
    fn parse(json: &str) -> QueryResult {
        serde_json::from_str(json).unwrap_or_else(|e| panic!("failed to parse {json}: {e}"))
    }

    #[test]
    fn query_result_deserializes_total_cost_usd() {
        let qr = parse(
            r#"{"result":"hello","session_id":"s1","total_cost_usd":0.042,"is_error":false}"#,
        );
        assert_eq!(qr.cost_usd, Some(0.042));
    }

    #[test]
    fn query_result_deserializes_cost_usd_alias() {
        let qr = parse(r#"{"result":"hello","session_id":"s1","cost_usd":0.01,"is_error":false}"#);
        assert_eq!(qr.cost_usd, Some(0.01));
    }

    #[test]
    fn query_result_missing_cost_defaults_to_none() {
        let qr = parse(r#"{"result":"hello","session_id":"s1","is_error":false}"#);
        assert!(qr.cost_usd.is_none());
    }

    #[test]
    fn query_result_from_stream_result_event() {
        let qr = parse(
            r#"{"type":"result","subtype":"success","result":"streamed","session_id":"sess-1","total_cost_usd":0.03,"num_turns":1,"is_error":false}"#,
        );
        assert_eq!(qr.result, "streamed");
        assert_eq!(qr.session_id, "sess-1");
        assert_eq!(qr.cost_usd, Some(0.03));
        assert_eq!(qr.num_turns, Some(1));
        let event_type = qr.extra.get("type").and_then(serde_json::Value::as_str);
        assert_eq!(event_type, Some("result"));
    }

    #[test]
    fn query_result_deserializes_num_turns() {
        let qr = parse(
            r#"{"result":"done","session_id":"s2","total_cost_usd":0.1,"num_turns":5,"is_error":false}"#,
        );
        assert_eq!(qr.cost_usd, Some(0.1));
        assert_eq!(qr.num_turns, Some(5));
    }

    #[test]
    fn query_result_serializes_as_total_cost_usd() {
        let qr = QueryResult {
            result: "ok".into(),
            session_id: "s1".into(),
            cost_usd: Some(0.05),
            duration_ms: None,
            num_turns: Some(3),
            is_error: false,
            extra: Default::default(),
        };
        let value = serde_json::to_value(&qr).unwrap();
        assert!(value.get("total_cost_usd").is_some());
        assert!(value.get("num_turns").is_some());
    }
}