use std::fmt;
use std::str::FromStr;
use serde::{Deserialize, Serialize};
/// Transport kind with three variants: `Stdio`, `Http` and `Sse`.
///
/// Under serde every variant is (de)serialized using its lowercase name
/// (`"stdio"`, `"http"`, `"sse"`) because of `rename_all = "lowercase"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Transport {
/// Serialized as `"stdio"`.
Stdio,
/// Serialized as `"http"`.
Http,
/// Serialized as `"sse"`.
Sse,
}
impl fmt::Display for Transport {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Stdio => write!(f, "stdio"),
Self::Http => write!(f, "http"),
Self::Sse => write!(f, "sse"),
}
}
}
/// Error carrying a string that is not a known transport name.
///
/// The wrapped `String` is interpolated into the `Display` message, which
/// reads `unknown transport: <value>`.
#[derive(Debug, Clone, thiserror::Error)]
#[error("unknown transport: {0}")]
pub struct TransportParseError(pub String);
impl FromStr for Transport {
type Err = TransportParseError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"stdio" => Ok(Self::Stdio),
"http" => Ok(Self::Http),
"sse" => Ok(Self::Sse),
other => Err(TransportParseError(other.to_string())),
}
}
}
impl TryFrom<&str> for Transport {
type Error = TransportParseError;
fn try_from(s: &str) -> Result<Self, Self::Error> {
s.parse()
}
}
#[cfg(test)]
mod transport_tests {
    use super::*;

    /// Each accepted name parses to its matching variant.
    #[test]
    fn from_str_accepts_known_variants() {
        let cases = [
            ("stdio", Transport::Stdio),
            ("http", Transport::Http),
            ("sse", Transport::Sse),
        ];
        for (input, expected) in cases {
            let parsed = input.parse::<Transport>().unwrap();
            assert_eq!(parsed, expected);
        }
    }

    /// The rejected input is echoed back in the error message.
    #[test]
    fn from_str_returns_error_for_unknown() {
        let message = "websocket"
            .parse::<Transport>()
            .unwrap_err()
            .to_string();
        assert!(message.contains("websocket"));
    }

    /// `TryFrom<&str>` rejects unknown names as well.
    #[test]
    fn try_from_str_returns_error_for_unknown() {
        let outcome = Transport::try_from("bogus");
        assert!(outcome.is_err());
    }
}
/// Output format selector; [`OutputFormat::Text`] is the default.
///
/// Derives `PartialEq`/`Eq` like the sibling option enums in this file so
/// values can be compared directly and used in `assert_eq!`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum OutputFormat {
    /// Mapped to `text` by `as_arg`; the default variant.
    #[default]
    Text,
    /// Mapped to `json` by `as_arg`.
    Json,
    /// Mapped to `stream-json` by `as_arg`.
    StreamJson,
}

impl OutputFormat {
    /// Returns the fixed string associated with this variant.
    ///
    /// NOTE(review): the name suggests this is a command-line argument
    /// value; confirm at the call sites.
    pub(crate) fn as_arg(&self) -> &'static str {
        match self {
            Self::Text => "text",
            Self::Json => "json",
            Self::StreamJson => "stream-json",
        }
    }
}
/// Permission mode selector; [`PermissionMode::Default`] is the default.
///
/// Under serde every variant is (de)serialized using its camelCase name
/// (`"default"`, `"acceptEdits"`, `"bypassPermissions"`, `"dontAsk"`,
/// `"plan"`, `"auto"`) because of `rename_all = "camelCase"`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum PermissionMode {
/// Serialized as `"default"`; the default variant.
#[default]
Default,
/// Serialized as `"acceptEdits"`.
AcceptEdits,
/// Serialized as `"bypassPermissions"`. Deprecated since 0.5.1; the
/// deprecation note below points to the replacement.
#[deprecated(
since = "0.5.1",
note = "use claude_wrapper::dangerous::DangerousClient instead; \
direct BypassPermissions usage is a footgun and will go \
away in a future major release"
)]
BypassPermissions,
/// Serialized as `"dontAsk"`.
DontAsk,
/// Serialized as `"plan"`.
Plan,
/// Serialized as `"auto"`.
Auto,
}
impl PermissionMode {
pub(crate) fn as_arg(&self) -> &'static str {
match self {
Self::Default => "default",
Self::AcceptEdits => "acceptEdits",
#[allow(deprecated)]
Self::BypassPermissions => "bypassPermissions",
Self::DontAsk => "dontAsk",
Self::Plan => "plan",
Self::Auto => "auto",
}
}
}
/// Input format selector; [`InputFormat::Text`] is the default.
///
/// Derives `PartialEq`/`Eq` like the sibling option enums in this file so
/// values can be compared directly and used in `assert_eq!`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum InputFormat {
    /// Mapped to `text` by `as_arg`; the default variant.
    #[default]
    Text,
    /// Mapped to `stream-json` by `as_arg`.
    StreamJson,
}

impl InputFormat {
    /// Returns the fixed string associated with this variant.
    ///
    /// NOTE(review): the name suggests this is a command-line argument
    /// value; confirm at the call sites.
    pub(crate) fn as_arg(&self) -> &'static str {
        match self {
            Self::Text => "text",
            Self::StreamJson => "stream-json",
        }
    }
}
/// Effort level with four variants: `Low`, `Medium`, `High` and `Max`.
///
/// Under serde every variant is (de)serialized using its lowercase name
/// because of `rename_all = "lowercase"`. The enum has no default variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Effort {
/// Serialized as `"low"`.
Low,
/// Serialized as `"medium"`.
Medium,
/// Serialized as `"high"`.
High,
/// Serialized as `"max"`.
Max,
}
impl Effort {
pub(crate) fn as_arg(&self) -> &'static str {
match self {
Self::Low => "low",
Self::Medium => "medium",
Self::High => "high",
Self::Max => "max",
}
}
}
/// Scope selector; [`Scope::Local`] is the default.
///
/// Derives `PartialEq`/`Eq` like the sibling option enums in this file so
/// values can be compared directly and used in `assert_eq!`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum Scope {
    /// Mapped to `local` by `as_arg`; the default variant.
    #[default]
    Local,
    /// Mapped to `user` by `as_arg`.
    User,
    /// Mapped to `project` by `as_arg`.
    Project,
}

impl Scope {
    /// Returns the fixed string associated with this variant.
    ///
    /// NOTE(review): the name suggests this is a command-line argument
    /// value; confirm at the call sites.
    pub(crate) fn as_arg(&self) -> &'static str {
        match self {
            Self::Local => "local",
            Self::User => "user",
            Self::Project => "project",
        }
    }
}
/// Authentication status record, compiled only with the `json` feature.
///
/// Keys are camelCase on the wire (`loggedIn`, `authMethod`, ...) because of
/// `rename_all = "camelCase"`. Every named field is `#[serde(default)]`, so a
/// missing key yields `false` / `None` instead of a deserialization error.
#[cfg(feature = "json")]
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct AuthStatus {
/// Wire key `loggedIn`; `false` when absent.
#[serde(default)]
pub logged_in: bool,
/// Wire key `authMethod`.
#[serde(default)]
pub auth_method: Option<String>,
/// Wire key `apiProvider`.
#[serde(default)]
pub api_provider: Option<String>,
/// Wire key `email`.
#[serde(default)]
pub email: Option<String>,
/// Wire key `orgId`.
#[serde(default)]
pub org_id: Option<String>,
/// Wire key `orgName`.
#[serde(default)]
pub org_name: Option<String>,
/// Wire key `subscriptionType`.
#[serde(default)]
pub subscription_type: Option<String>,
/// Every key not matched by a named field above, kept as raw JSON and
/// written back at the top level on serialization (`flatten`).
#[serde(flatten)]
pub extra: std::collections::HashMap<String, serde_json::Value>,
}
/// A message with a `role` and JSON `content`, compiled only with the `json`
/// feature.
///
/// Both named fields are `#[serde(default)]`, so a missing key falls back to
/// the type's default instead of failing deserialization.
#[cfg(feature = "json")]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct QueryMessage {
/// Wire key `role`; empty string when absent.
#[serde(default)]
pub role: String,
/// Wire key `content`, kept as arbitrary JSON; `null` when absent.
#[serde(default)]
pub content: serde_json::Value,
/// Every key not matched by a named field above, kept as raw JSON and
/// written back at the top level on serialization (`flatten`).
#[serde(flatten)]
pub extra: std::collections::HashMap<String, serde_json::Value>,
}
/// Result record of a query, compiled only with the `json` feature.
///
/// Keys are snake_case on the wire. Every named field is `#[serde(default)]`,
/// so a missing key falls back to the type's default instead of failing
/// deserialization.
#[cfg(feature = "json")]
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct QueryResult {
/// Wire key `result`; empty string when absent.
#[serde(default)]
pub result: String,
/// Wire key `session_id`; empty string when absent.
#[serde(default)]
pub session_id: String,
/// Read from `total_cost_usd`, or from the alias `cost_usd`; always written
/// as `total_cost_usd`. `None` when neither key is present.
#[serde(default, rename = "total_cost_usd", alias = "cost_usd")]
pub cost_usd: Option<f64>,
/// Wire key `duration_ms`.
#[serde(default)]
pub duration_ms: Option<u64>,
/// Wire key `num_turns`.
#[serde(default)]
pub num_turns: Option<u32>,
/// Wire key `is_error`; `false` when absent.
#[serde(default)]
pub is_error: bool,
/// Every key not matched by a named field above (for example `type` or
/// `subtype`), kept as raw JSON and written back at the top level on
/// serialization (`flatten`).
#[serde(flatten)]
pub extra: std::collections::HashMap<String, serde_json::Value>,
}
#[cfg(all(test, feature = "json"))]
mod tests {
    use super::*;

    /// Decodes `payload` into a [`QueryResult`], panicking if it is malformed.
    fn decode(payload: &str) -> QueryResult {
        serde_json::from_str(payload).unwrap()
    }

    /// The primary key `total_cost_usd` populates `cost_usd`.
    #[test]
    fn query_result_deserializes_total_cost_usd() {
        let parsed = decode(
            r#"{"result":"hello","session_id":"s1","total_cost_usd":0.042,"is_error":false}"#,
        );
        assert_eq!(parsed.cost_usd, Some(0.042));
    }

    /// The alias key `cost_usd` populates `cost_usd` too.
    #[test]
    fn query_result_deserializes_cost_usd_alias() {
        let parsed =
            decode(r#"{"result":"hello","session_id":"s1","cost_usd":0.01,"is_error":false}"#);
        assert_eq!(parsed.cost_usd, Some(0.01));
    }

    /// Without either cost key the field falls back to `None`.
    #[test]
    fn query_result_missing_cost_defaults_to_none() {
        let parsed = decode(r#"{"result":"hello","session_id":"s1","is_error":false}"#);
        assert_eq!(parsed.cost_usd, None);
    }

    /// A stream `result` event decodes, with unknown keys landing in `extra`.
    #[test]
    fn query_result_from_stream_result_event() {
        let parsed = decode(
            r#"{"type":"result","subtype":"success","result":"streamed","session_id":"sess-1","total_cost_usd":0.03,"num_turns":1,"is_error":false}"#,
        );
        assert_eq!(parsed.cost_usd, Some(0.03));
        assert_eq!(parsed.num_turns, Some(1));
        assert_eq!(parsed.session_id, "sess-1");
        assert_eq!(parsed.result, "streamed");

        let event_type = parsed.extra.get("type").and_then(|value| value.as_str());
        assert_eq!(event_type, Some("result"));
    }

    /// `num_turns` and the cost are both picked up from one payload.
    #[test]
    fn query_result_deserializes_num_turns() {
        let parsed = decode(
            r#"{"result":"done","session_id":"s2","total_cost_usd":0.1,"num_turns":5,"is_error":false}"#,
        );
        assert_eq!(parsed.num_turns, Some(5));
        assert_eq!(parsed.cost_usd, Some(0.1));
    }

    /// Serialization uses the renamed key `total_cost_usd`.
    #[test]
    fn query_result_serializes_as_total_cost_usd() {
        let value = QueryResult {
            result: "ok".into(),
            session_id: "s1".into(),
            cost_usd: Some(0.05),
            duration_ms: None,
            num_turns: Some(3),
            is_error: false,
            extra: Default::default(),
        };
        let encoded = serde_json::to_string(&value).unwrap();
        for key in ["\"total_cost_usd\"", "\"num_turns\""] {
            assert!(encoded.contains(key));
        }
    }
}