use serde::{Deserialize, Serialize};
/// A single message from the SDK message stream, wrapping one of the
/// `super::SDK*` payload types.
///
/// Serialization is `#[serde(untagged)]`:
/// - When serializing, the wrapped payload is written directly, with no extra tag.
/// - When deserializing, serde tries each variant in declaration order and keeps
///   the first one whose payload deserializes successfully.
///
/// Variant order is therefore significant. For example, if an
/// `SDKUserMessageReplay` payload would also deserialize as `SDKUserMessage`,
/// it would be decoded as `UserMessage`, because that variant is listed first.
///
/// NOTE(review): confirm the wrapped types are mutually distinguishable (e.g. via
/// a constrained `type` field or `deny_unknown_fields`), so that no variant
/// shadows a later one.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(untagged)]
pub enum SDKMessage {
AssistantMessage(super::SDKAssistantMessage),
UserMessage(super::SDKUserMessage),
UserMessageReplay(super::SDKUserMessageReplay),
ResultMessage(super::SDKResultMessage),
SystemMessage(super::SDKSystemMessage),
PartialAssistantMessage(super::SDKPartialAssistantMessage),
CompactBoundaryMessage(super::SDKCompactBoundaryMessage),
StatusMessage(super::SDKStatusMessage),
HookStartedMessage(super::SDKHookStartedMessage),
HookProgressMessage(super::SDKHookProgressMessage),
HookResponseMessage(super::SDKHookResponseMessage),
ToolProgressMessage(super::SDKToolProgressMessage),
AuthStatusMessage(super::SDKAuthStatusMessage),
TaskNotificationMessage(super::SDKTaskNotificationMessage),
TaskStartedMessage(super::SDKTaskStartedMessage),
FilesPersistedEvent(super::SDKFilesPersistedEvent),
ToolUseSummaryMessage(super::SDKToolUseSummaryMessage),
RateLimitEvent(super::SDKRateLimitEvent),
}
impl SDKMessage {
    /// Returns the session id carried by this message, if this accessor knows
    /// how to read it for the message's variant.
    ///
    /// A session id is reported only for these variants:
    /// - partial-assistant messages
    /// - result messages (via `SDKResultMessage::session_id()`)
    /// - user messages
    /// - assistant messages
    ///
    /// Every other variant yields `None`.
    pub fn session_id(&self) -> Option<&str> {
        match self {
            Self::PartialAssistantMessage(msg) => Some(&msg.session_id),
            Self::ResultMessage(msg) => Some(msg.session_id()),
            Self::UserMessage(msg) => Some(&msg.session_id),
            Self::AssistantMessage(msg) => Some(&msg.session_id),
            // Listed explicitly instead of a `_` arm, so that adding a variant
            // to `SDKMessage` forces a decision about its session id rather
            // than silently returning `None`.
            Self::UserMessageReplay(_)
            | Self::SystemMessage(_)
            | Self::CompactBoundaryMessage(_)
            | Self::StatusMessage(_)
            | Self::HookStartedMessage(_)
            | Self::HookProgressMessage(_)
            | Self::HookResponseMessage(_)
            | Self::ToolProgressMessage(_)
            | Self::AuthStatusMessage(_)
            | Self::TaskNotificationMessage(_)
            | Self::TaskStartedMessage(_)
            | Self::FilesPersistedEvent(_)
            | Self::ToolUseSummaryMessage(_)
            | Self::RateLimitEvent(_) => None,
        }
    }

    /// Converts this SDK message into a downstream streaming completion chunk,
    /// consuming the message.
    ///
    /// The parameters are forwarded unchanged to the per-variant converters
    /// that need them:
    /// - `id`, `created`, `assistant_index`, `upstream`: partial-assistant,
    ///   user, and result messages.
    /// - `agent`: partial-assistant and result messages.
    /// - `is_byok`, `cost_multiplier`: result messages only.
    ///
    /// # Returns
    ///
    /// - `Some(Ok(chunk))` when the variant's converter produces a chunk.
    ///   Result messages always do. Partial-assistant and user messages do
    ///   whenever their own `into_downstream` returns `Some`.
    /// - `Some(Err(Error::RateLimit))` for a rate-limit event whose info status
    ///   is `RateLimitStatus::Rejected`, or whose type is
    ///   `RateLimitEventType::RateLimit`.
    /// - `None` for every other case: other rate-limit events, converters that
    ///   return `None`, and all remaining variants (including full assistant
    ///   messages).
    pub fn into_downstream(
        self,
        id: String,
        created: u64,
        agent: String,
        assistant_index: u64,
        is_byok: bool,
        cost_multiplier: rust_decimal::Decimal,
        upstream: objectiveai_sdk::agent::Upstream,
    ) -> Option<
        Result<
            objectiveai_sdk::agent::completions::response::streaming::AgentCompletionChunk,
            super::super::Error,
        >,
    > {
        match self {
            Self::PartialAssistantMessage(msg) => msg
                .into_downstream(id, created, agent, assistant_index, upstream)
                .map(Ok),
            Self::UserMessage(msg) => msg
                .into_downstream(id, created, assistant_index, upstream)
                .map(Ok),
            Self::ResultMessage(msg) => Some(Ok(msg.into_downstream(
                id,
                created,
                agent,
                assistant_index,
                is_byok,
                cost_multiplier,
                upstream,
            ))),
            Self::RateLimitEvent(evt) => {
                use super::sdk_rate_limit_event::{RateLimitEventType, RateLimitStatus};
                // A missing info block or missing status counts as "not rejected".
                let rejected = matches!(
                    evt.rate_limit_info.and_then(|info| info.status),
                    Some(RateLimitStatus::Rejected)
                );
                // NOTE(review): any event typed `RateLimit` is treated as fatal
                // regardless of status. If `RateLimit` is the only
                // `RateLimitEventType` variant, `rejected` never matters.
                // Confirm this is intended.
                let terminal = matches!(evt.r#type, RateLimitEventType::RateLimit);
                if rejected || terminal {
                    Some(Err(super::super::Error::RateLimit))
                } else {
                    None
                }
            }
            Self::AssistantMessage(_)
            | Self::UserMessageReplay(_)
            | Self::SystemMessage(_)
            | Self::CompactBoundaryMessage(_)
            | Self::StatusMessage(_)
            | Self::HookStartedMessage(_)
            | Self::HookProgressMessage(_)
            | Self::HookResponseMessage(_)
            | Self::ToolProgressMessage(_)
            | Self::AuthStatusMessage(_)
            | Self::TaskNotificationMessage(_)
            | Self::TaskStartedMessage(_)
            | Self::FilesPersistedEvent(_)
            | Self::ToolUseSummaryMessage(_) => None,
        }
    }
}