use std::backtrace::{Backtrace, BacktraceStatus};
use std::borrow::{Borrow, Cow};
use std::fmt;
use std::time::Duration;
use crate::types::SessionId;
/// Shorthand for [`std::result::Result`] with this module's [`Error`] as the error type.
pub type Result<T> = std::result::Result<T, Error>;
/// Internal storage behind [`Error`]: a kind plus whatever extra context was attached.
#[derive(Debug)]
pub(crate) enum Repr<T: fmt::Debug> {
    /// Only a kind, with no message or source error.
    Simple(T),
    /// A kind plus a human-readable message (built by `Error::with_message`).
    SimpleMessage(T, Cow<'static, str>),
    /// A kind plus a boxed underlying error (built by `Error::new`).
    Custom(Custom<T>),
}
/// A kind paired with the underlying error that caused it.
#[derive(Debug)]
pub(crate) struct Custom<T: fmt::Debug> {
    /// Classification of the failure.
    pub(crate) kind: T,
    /// The wrapped error; `Error`'s `std::error::Error::source` returns it.
    pub(crate) error: Box<dyn std::error::Error + Send + Sync>,
}
/// Failures in the message framing protocol or while starting / handshaking with the CLI.
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum ProtocolErrorKind {
    /// A message was missing its `Content-Length` header.
    MissingContentLength,
    /// The `Content-Length` header value was invalid; holds the raw value as received.
    InvalidContentLength(String),
    /// The request was cancelled.
    RequestCancelled,
    /// Timed out waiting for the CLI to report its listening port.
    CliStartupTimeout,
    /// The CLI exited before reporting its listening port.
    CliStartupFailed,
    /// The server's reported version lies outside the supported `min..=max` range.
    VersionMismatch {
        /// Version reported by the server.
        server: u32,
        /// Lowest supported version.
        min: u32,
        /// Highest supported version.
        max: u32,
    },
    /// The reported version changed from `previous` to `current`.
    VersionChanged {
        /// Version seen earlier.
        previous: u32,
        /// Version seen now.
        current: u32,
    },
}

impl fmt::Display for ProtocolErrorKind {
    /// Renders a one-line, human-readable description of the protocol failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ProtocolErrorKind::MissingContentLength => {
                write!(f, "missing Content-Length header")
            }
            ProtocolErrorKind::InvalidContentLength(v) => {
                // The raw header value presumably comes from the peer, so render it with
                // `Debug` escaping. Ordinary values still print as `"value"`, but embedded
                // quotes, CR/LF and other control characters are escaped rather than spliced
                // verbatim into the message (and from there into log lines).
                write!(f, "invalid Content-Length value: {v:?}")
            }
            ProtocolErrorKind::RequestCancelled => write!(f, "request cancelled"),
            ProtocolErrorKind::CliStartupTimeout => {
                write!(f, "timed out waiting for CLI to report listening port")
            }
            ProtocolErrorKind::CliStartupFailed => {
                write!(f, "CLI exited before reporting listening port")
            }
            ProtocolErrorKind::VersionMismatch { server, min, max } => {
                write!(
                    f,
                    "version mismatch: server={server}, supported={min}\u{2013}{max}"
                )
            }
            ProtocolErrorKind::VersionChanged { previous, current } => {
                write!(f, "version changed: was {previous}, now {current}")
            }
        }
    }
}
/// Failures tied to a particular session.
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum SessionErrorKind {
    /// No session with the given ID was found.
    NotFound(SessionId),
    /// An agent error occurred; this kind carries no further detail.
    AgentError,
    /// An operation timed out after the given duration.
    Timeout(Duration),
    /// A send was attempted while a `send_and_wait` was still in flight.
    SendWhileWaiting,
    /// The event loop closed before the session reached idle.
    EventLoopClosed,
    /// The host does not support elicitation. Callers should check
    /// `session.capabilities().ui.elicitation` first.
    ElicitationNotSupported,
    /// The session was created on a client with `session_fs` configured, but no
    /// `SessionFsProvider` was supplied.
    SessionFsProviderRequired,
    /// The supplied `SessionFsConfig` was invalid.
    InvalidSessionFsConfig,
    /// The CLI returned a different session ID from the one the SDK registered.
    SessionIdMismatch {
        /// ID the SDK registered.
        requested: SessionId,
        /// ID the CLI actually returned.
        returned: SessionId,
    },
}
impl fmt::Display for SessionErrorKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
SessionErrorKind::NotFound(id) => write!(f, "session not found: {id}"),
SessionErrorKind::AgentError => write!(f, "agent error"),
SessionErrorKind::Timeout(d) => write!(f, "timed out after {d:?}"),
SessionErrorKind::SendWhileWaiting => {
write!(f, "cannot send while send_and_wait is in flight")
}
SessionErrorKind::EventLoopClosed => {
write!(f, "event loop closed before session reached idle")
}
SessionErrorKind::ElicitationNotSupported => write!(
f,
"elicitation not supported by host \
\u{2014} check session.capabilities().ui.elicitation first"
),
SessionErrorKind::SessionFsProviderRequired => write!(
f,
"session was created on a client with session_fs configured \
but no SessionFsProvider was supplied"
),
SessionErrorKind::InvalidSessionFsConfig => {
write!(f, "invalid SessionFsConfig")
}
SessionErrorKind::SessionIdMismatch {
requested,
returned,
} => write!(
f,
"CLI returned session ID {returned} after SDK registered {requested}"
),
}
}
}
/// Top-level classification of an [`Error`]; see [`Error::kind`].
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum ErrorKind {
    /// Framing-protocol or CLI startup failure; see [`ProtocolErrorKind`].
    Protocol(ProtocolErrorKind),
    /// An RPC call failed with the given error code.
    Rpc {
        /// Code reported for the failed call; also exposed by [`Error::rpc_code`].
        code: i32,
    },
    /// Session-level failure; see [`SessionErrorKind`].
    Session(SessionErrorKind),
    /// An I/O operation failed. `From<std::io::Error>` produces this kind.
    Io,
    /// JSON (de)serialization failed. `From<serde_json::Error>` produces this kind.
    Json,
    /// A required binary could not be found.
    BinaryNotFound {
        /// Name of the missing binary.
        name: String,
        /// Optional guidance, appended to the message in parentheses when present.
        hint: Option<String>,
    },
    /// The supplied configuration was invalid.
    InvalidConfig,
}
impl fmt::Display for ErrorKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ErrorKind::Protocol(k) => write!(f, "{k}"),
ErrorKind::Rpc { code } => write!(f, "RPC error {code}"),
ErrorKind::Session(k) => write!(f, "{k}"),
ErrorKind::Io => write!(f, "I/O error"),
ErrorKind::Json => write!(f, "JSON error"),
ErrorKind::BinaryNotFound {
name,
hint: Some(h),
} => {
write!(f, "binary not found: {name} ({h})")
}
ErrorKind::BinaryNotFound { name, hint: None } => {
write!(f, "binary not found: {name}")
}
ErrorKind::InvalidConfig => write!(f, "invalid configuration"),
}
}
}
/// The error type for this module; [`Result`] is an alias using it.
///
/// Every error carries an [`ErrorKind`] (see [`Error::kind`]). It may also carry a
/// message or a wrapped source error, plus a backtrace if one was captured.
pub struct Error {
    // Kind plus any attached message or source error.
    repr: Repr<ErrorKind>,
    // Backtrace recorded at construction; `None` unless std actually captured one
    // (std enables capture via the `RUST_BACKTRACE` / `RUST_LIB_BACKTRACE` env vars).
    backtrace: Option<Box<Backtrace>>,
}
impl Error {
    /// Builds an error of `kind` that wraps `error` as its source.
    ///
    /// A backtrace is captured if std has backtrace capture enabled.
    pub(crate) fn new<E>(kind: ErrorKind, error: E) -> Self
    where
        E: Into<Box<dyn std::error::Error + Send + Sync>>,
    {
        let wrapped = Custom {
            kind,
            error: error.into(),
        };
        Self {
            repr: Repr::Custom(wrapped),
            backtrace: capture_backtrace(),
        }
    }

    /// Returns the classification of this error, however it was constructed.
    pub fn kind(&self) -> &ErrorKind {
        match &self.repr {
            Repr::Simple(kind) => kind,
            Repr::SimpleMessage(kind, _) => kind,
            Repr::Custom(custom) => &custom.kind,
        }
    }

    /// Returns the message given to [`Error::with_message`], if any.
    ///
    /// Errors built from a bare kind, or wrapping another error, return `None`.
    pub fn message(&self) -> Option<&str> {
        if let Repr::SimpleMessage(_, text) = &self.repr {
            Some(text.borrow())
        } else {
            None
        }
    }

    /// Builds an error of `kind` carrying a human-readable `message`.
    #[must_use]
    pub fn with_message<C>(kind: ErrorKind, message: C) -> Self
    where
        C: Into<Cow<'static, str>>,
    {
        let repr = Repr::SimpleMessage(kind, message.into());
        Self {
            repr,
            backtrace: capture_backtrace(),
        }
    }

    /// Returns `true` for I/O failures and cancelled requests.
    pub fn is_transport_failure(&self) -> bool {
        matches!(
            self.kind(),
            ErrorKind::Io | ErrorKind::Protocol(ProtocolErrorKind::RequestCancelled)
        )
    }

    /// Returns the numeric code when this is an [`ErrorKind::Rpc`] error.
    pub fn rpc_code(&self) -> Option<i32> {
        if let ErrorKind::Rpc { code } = self.kind() {
            Some(*code)
        } else {
            None
        }
    }
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match &self.repr {
Repr::Simple(kind) => write!(f, "{kind}"),
Repr::SimpleMessage(kind, message) if matches!(kind, ErrorKind::Rpc { code: _ }) => {
write!(f, "{kind}: {message}")
}
Repr::SimpleMessage(_, message) => write!(f, "{message}"),
Repr::Custom(Custom { kind, error }) if matches!(kind, ErrorKind::Rpc { code: _ }) => {
write!(f, "{kind}: {error}")
}
Repr::Custom(Custom { error, .. }) => write!(f, "{error}"),
}
}
}
impl fmt::Debug for Error {
    /// Debug view: always shows `context` (the internal representation).
    ///
    /// When a backtrace was captured it is shown too. Otherwise the struct is closed
    /// with `..` to signal that a field was omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut view = f.debug_struct("Error");
        view.field("context", &self.repr);
        match self.backtrace.as_deref() {
            Some(trace) => view.field("backtrace", trace).finish(),
            None => view.finish_non_exhaustive(),
        }
    }
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match &self.repr {
Repr::Custom(Custom { error, .. }) => Some(&**error),
_ => None,
}
}
}
impl From<ErrorKind> for Error {
fn from(kind: ErrorKind) -> Self {
Self {
repr: Repr::Simple(kind),
backtrace: capture_backtrace(),
}
}
}
impl From<ProtocolErrorKind> for Error {
fn from(kind: ProtocolErrorKind) -> Self {
Self::from(ErrorKind::Protocol(kind))
}
}
impl From<SessionErrorKind> for Error {
fn from(kind: SessionErrorKind) -> Self {
Self::from(ErrorKind::Session(kind))
}
}
impl From<std::io::Error> for Error {
fn from(error: std::io::Error) -> Self {
Self::new(ErrorKind::Io, error)
}
}
impl From<serde_json::Error> for Error {
fn from(error: serde_json::Error) -> Self {
Self::new(ErrorKind::Json, error)
}
}
/// Captures a backtrace at the call site, keeping it only if std actually recorded one.
///
/// `Backtrace::capture` returns a disabled or unsupported placeholder unless capture is
/// enabled through the environment. Placeholders are discarded, so callers store only
/// real traces. Boxing keeps the `Option` pointer-sized.
// NOTE(review): `inline(always)` is presumably meant to keep this helper's own frame out
// of the recorded trace — confirm before changing it.
#[inline(always)]
fn capture_backtrace() -> Option<Box<Backtrace>> {
    let trace = Backtrace::capture();
    let recorded = trace.status() == BacktraceStatus::Captured;
    recorded.then(|| Box::new(trace))
}
/// Errors collected while stopping. It may hold zero, one, or several [`Error`]s.
#[derive(Debug)]
pub struct StopErrors(pub(crate) Vec<Error>);
impl StopErrors {
pub fn errors(&self) -> &[Error] {
&self.0
}
pub fn into_errors(self) -> Vec<Error> {
self.0
}
}
impl fmt::Display for StopErrors {
    /// Summarizes the stop outcome.
    ///
    /// A single error is shown in full. With several errors, the count and the first
    /// error are shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let total = self.0.len();
        match self.0.first() {
            None => write!(f, "stop completed with no errors"),
            Some(only) if total == 1 => write!(f, "stop failed: {only}"),
            Some(first) => write!(f, "stop failed with {total} errors; first: {first}"),
        }
    }
}
impl std::error::Error for StopErrors {
    /// Exposes the first collected error as the source.
    ///
    /// The remaining errors are reachable only through [`StopErrors::errors`].
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let first: &(dyn std::error::Error + 'static) = self.0.first()?;
        Some(first)
    }
}