use std::ffi::c_void;
use super::{
collector::{StreamTiming, Timing, Usage},
http::{HTTPRequest, HTTPResponse, SSEResponse},
RawObject, RawObjectTrait,
};
use crate::{baml_unreachable, codec::traits::DecodeHandle, proto::baml_cffi_v1::BamlObjectType};
// Declares the `LLMCall` wrapper type via `define_raw_object_wrapper!`, pairing it
// with the `ObjectLlmCall` identifier.
// NOTE(review): `ObjectLlmCall` is presumably the matching `BamlObjectType` variant
// (the same name is matched on in `LLMCallKind::decode_handle`) — confirm against
// the macro definition.
// The accessor methods in this file reach the wrapped object through `self.raw`.
define_raw_object_wrapper! {
LLMCall => ObjectLlmCall
}
impl LLMCall {
    /// Returns the result of the raw object's `http_request_id` method.
    pub fn request_id(&self) -> String {
        self.raw.call_method("http_request_id", ())
    }

    /// Returns the result of the raw object's `client_name` method.
    pub fn client_name(&self) -> String {
        self.raw.call_method("client_name", ())
    }

    /// Returns the result of the raw object's `provider` method.
    pub fn provider(&self) -> String {
        self.raw.call_method("provider", ())
    }

    /// Returns the optional [`HTTPRequest`] produced by the raw object's
    /// `http_request` method.
    ///
    /// An error from the underlying call is handed to `baml_unreachable!`.
    pub fn http_request(&self) -> Option<HTTPRequest> {
        match self.raw.call_method_for_object_optional("http_request", ()) {
            Ok(request) => request,
            Err(e) => baml_unreachable!("Failed to get HTTP request: {e}"),
        }
    }

    /// Returns the optional [`HTTPResponse`] produced by the raw object's
    /// `http_response` method.
    ///
    /// An error from the underlying call is handed to `baml_unreachable!`.
    pub fn http_response(&self) -> Option<HTTPResponse> {
        match self.raw.call_method_for_object_optional("http_response", ()) {
            Ok(response) => response,
            Err(e) => baml_unreachable!("Failed to get HTTP response: {e}"),
        }
    }

    /// Returns the optional [`Usage`] produced by the raw object's `usage` method.
    ///
    /// An error from the underlying call is handed to `baml_unreachable!`.
    pub fn usage(&self) -> Option<Usage> {
        match self.raw.call_method_for_object_optional("usage", ()) {
            Ok(usage) => usage,
            Err(e) => baml_unreachable!("Failed to get usage: {e}"),
        }
    }

    /// Returns the result of the raw object's `selected` method.
    pub fn selected(&self) -> bool {
        self.raw.call_method("selected", ())
    }

    /// Returns the [`Timing`] produced by the raw object's `timing` method.
    ///
    /// An error from the underlying call is handed to `baml_unreachable!`.
    pub fn timing(&self) -> Timing {
        match self.raw.call_method_for_object("timing", ()) {
            Ok(timing) => timing,
            Err(e) => baml_unreachable!("Failed to get timing: {e}"),
        }
    }
}
// Declares the `LLMStreamCall` wrapper type via `define_raw_object_wrapper!`, pairing
// it with the `ObjectLlmStreamCall` identifier.
// NOTE(review): `ObjectLlmStreamCall` is presumably the matching `BamlObjectType`
// variant (the same name is matched on in `LLMCallKind::decode_handle`) — confirm
// against the macro definition.
// The accessor methods in this file reach the wrapped object through `self.raw`.
define_raw_object_wrapper! {
LLMStreamCall => ObjectLlmStreamCall
}
impl LLMStreamCall {
    /// Returns the result of the raw object's `http_request_id` method.
    pub fn request_id(&self) -> String {
        self.raw.call_method("http_request_id", ())
    }

    /// Returns the result of the raw object's `client_name` method.
    pub fn client_name(&self) -> String {
        self.raw.call_method("client_name", ())
    }

    /// Returns the result of the raw object's `provider` method.
    pub fn provider(&self) -> String {
        self.raw.call_method("provider", ())
    }

    /// Returns the optional [`HTTPRequest`] produced by the raw object's
    /// `http_request` method.
    ///
    /// An error from the underlying call is handed to `baml_unreachable!`.
    pub fn http_request(&self) -> Option<HTTPRequest> {
        match self.raw.call_method_for_object_optional("http_request", ()) {
            Ok(request) => request,
            Err(e) => baml_unreachable!("Failed to get HTTP request: {e}"),
        }
    }

    /// Returns the optional [`HTTPResponse`] produced by the raw object's
    /// `http_response` method.
    ///
    /// An error from the underlying call is handed to `baml_unreachable!`.
    pub fn http_response(&self) -> Option<HTTPResponse> {
        match self.raw.call_method_for_object_optional("http_response", ()) {
            Ok(response) => response,
            Err(e) => baml_unreachable!("Failed to get HTTP response: {e}"),
        }
    }

    /// Returns the optional [`Usage`] produced by the raw object's `usage` method.
    ///
    /// An error from the underlying call is handed to `baml_unreachable!`.
    pub fn usage(&self) -> Option<Usage> {
        match self.raw.call_method_for_object_optional("usage", ()) {
            Ok(usage) => usage,
            Err(e) => baml_unreachable!("Failed to get usage: {e}"),
        }
    }

    /// Returns the result of the raw object's `selected` method.
    pub fn selected(&self) -> bool {
        self.raw.call_method("selected", ())
    }

    /// Returns the [`StreamTiming`] produced by the raw object's `timing` method.
    ///
    /// An error from the underlying call is handed to `baml_unreachable!`.
    pub fn timing(&self) -> StreamTiming {
        match self.raw.call_method_for_object("timing", ()) {
            Ok(timing) => timing,
            Err(e) => baml_unreachable!("Failed to get timing: {e}"),
        }
    }

    /// Returns the optional list of [`SSEResponse`] values produced by the raw
    /// object's `sse_chunks` method.
    ///
    /// An error from the underlying call is handed to `baml_unreachable!`.
    pub fn sse_chunks(&self) -> Option<Vec<SSEResponse>> {
        match self.raw.call_method_for_objects_optional("sse_chunks", ()) {
            Ok(chunks) => chunks,
            Err(e) => baml_unreachable!("Failed to get SSE chunks: {e}"),
        }
    }
}
/// Sum type over the two call wrappers defined in this file: an [`LLMCall`] or an
/// [`LLMStreamCall`].
///
/// Use [`LLMCallKind::as_call`] / [`LLMCallKind::as_stream`] to borrow the concrete
/// wrapper, or the delegating accessors for data both variants expose.
#[derive(Clone)]
pub enum LLMCallKind {
/// Holds an [`LLMCall`] (decoded from an `ObjectLlmCall` handle).
Call(LLMCall),
/// Holds an [`LLMStreamCall`] (decoded from an `ObjectLlmStreamCall` handle).
Stream(LLMStreamCall),
}
impl DecodeHandle for LLMCallKind {
    /// Decodes `handle` into the variant that matches its object type.
    ///
    /// The object type is read with `object_type_from_handle`; the handle is then
    /// passed on to the `decode_handle` of the corresponding wrapper.
    ///
    /// # Errors
    ///
    /// Propagates any error from `object_type_from_handle` or from the wrapper's
    /// own `decode_handle`, and returns `BamlError::internal` when the object type
    /// is neither `ObjectLlmCall` nor `ObjectLlmStreamCall`.
    fn decode_handle(
        handle: crate::proto::baml_cffi_v1::BamlObjectHandle,
        runtime: *const c_void,
    ) -> Result<Self, crate::BamlError> {
        let object_type = super::object_type_from_handle(&handle)?;
        match object_type {
            BamlObjectType::ObjectLlmCall => {
                let call = LLMCall::decode_handle(handle, runtime)?;
                Ok(Self::Call(call))
            }
            BamlObjectType::ObjectLlmStreamCall => {
                let stream = LLMStreamCall::decode_handle(handle, runtime)?;
                Ok(Self::Stream(stream))
            }
            other => {
                let message = format!("invalid LLM call kind handle: {other:?}");
                Err(crate::BamlError::internal(message))
            }
        }
    }
}
impl LLMCallKind {
pub fn client_name(&self) -> String {
match self {
LLMCallKind::Call(c) => c.client_name(),
LLMCallKind::Stream(s) => s.client_name(),
}
}
pub fn provider(&self) -> String {
match self {
LLMCallKind::Call(c) => c.provider(),
LLMCallKind::Stream(s) => s.provider(),
}
}
pub fn selected(&self) -> bool {
match self {
LLMCallKind::Call(c) => c.selected(),
LLMCallKind::Stream(s) => s.selected(),
}
}
pub fn http_request(&self) -> Option<HTTPRequest> {
match self {
LLMCallKind::Call(c) => c.http_request(),
LLMCallKind::Stream(s) => s.http_request(),
}
}
pub fn usage(&self) -> Option<Usage> {
match self {
LLMCallKind::Call(c) => c.usage(),
LLMCallKind::Stream(s) => s.usage(),
}
}
pub fn as_call(&self) -> Option<&LLMCall> {
match self {
LLMCallKind::Call(c) => Some(c),
LLMCallKind::Stream(_) => None,
}
}
pub fn as_stream(&self) -> Option<&LLMStreamCall> {
match self {
LLMCallKind::Call(_) => None,
LLMCallKind::Stream(s) => Some(s),
}
}
}