use dashmap::DashMap;
use futures::future::Shared;
use futures::FutureExt;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::OnceCell;
use super::persistent_cache::PersistentCacheClient;
/// Shared state assembled by `Context::new` from a request's headers.
///
/// Every non-`Copy` field is behind an `Arc`, so the manual `Clone` impl
/// produces a handle that shares the same caches, memoization cells and
/// cancellation flag as the original.
#[derive(Debug)]
pub struct Context<CTXEXT, PC> {
/// Extension object; the authorization accessors call into it to obtain
/// values alongside the ones parsed from headers.
pub ext: Arc<CTXEXT>,
/// Presumably a multiplier applied to computed costs -- its use is not
/// visible in this file; confirm against callers.
pub cost_multiplier: rust_decimal::Decimal,
/// Presumably disables emitting output -- its use is not visible in this
/// file; confirm against callers.
pub suppress_output: bool,
/// Second-level cache consulted by `cached_get_or_fetch` before fetching.
persistent_cache: Arc<PC>,
// Raw values parsed from the request headers in `Context::new`.
// `None` means no candidate header was present, or the chosen header was not
// valid visible ASCII (and, for MCP, not a JSON string-to-string object).
objectiveai_authorization: Option<Arc<String>>,
openrouter_authorization: Option<Arc<String>>,
github_authorization: Option<Arc<String>>,
mcp_authorization: Option<Arc<HashMap<String, String>>>,
viewer_signature: Option<Arc<String>>,
viewer_address: Option<Arc<String>>,
commit_author_name: Option<Arc<String>>,
commit_author_email: Option<Arc<String>>,
// Memoized results of combining each header value above with the value
// supplied by `ext`; each cell is initialized on first use by the accessor
// method of the same name.
openrouter_authorization_cached: Arc<OnceCell<Option<Arc<String>>>>,
github_authorization_cached: Arc<OnceCell<Option<Arc<String>>>>,
mcp_authorization_cached: Arc<OnceCell<Option<Arc<HashMap<String, String>>>>>,
viewer_signature_cached: Arc<OnceCell<Option<Arc<String>>>>,
viewer_address_cached: Arc<OnceCell<Option<Arc<String>>>>,
commit_author_name_cached: Arc<OnceCell<Option<Arc<String>>>>,
commit_author_email_cached: Arc<OnceCell<Option<Arc<String>>>>,
/// Cancellation flag: set by `cancel`, read by `is_cancelled`.
cancelled: Arc<std::sync::atomic::AtomicBool>,
// In-memory caches used by the `cached_*` methods. Each entry is a shared
// oneshot receiver, so concurrent callers for the same key await a single
// resolution instead of each triggering their own fetch.
/// Backs `cached_agent`.
agent_cache: Arc<
DashMap<
objectiveai_sdk::RemotePath,
Shared<
tokio::sync::oneshot::Receiver<
Result<
Option<objectiveai_sdk::agent::RemoteAgentBaseWithFallbacks>,
objectiveai_sdk::error::ResponseError,
>,
>,
>,
>,
>,
/// Backs `cached_swarm`.
swarm_cache: Arc<
DashMap<
objectiveai_sdk::RemotePath,
Shared<
tokio::sync::oneshot::Receiver<
Result<
Option<objectiveai_sdk::swarm::RemoteSwarmBase>,
objectiveai_sdk::error::ResponseError,
>,
>,
>,
>,
>,
/// Backs `cached_function`.
function_cache: Arc<
DashMap<
objectiveai_sdk::RemotePath,
Shared<
tokio::sync::oneshot::Receiver<
Result<
Option<objectiveai_sdk::functions::FullRemoteFunction>,
objectiveai_sdk::error::ResponseError,
>,
>,
>,
>,
>,
/// Backs `cached_profile`.
profile_cache: Arc<
DashMap<
objectiveai_sdk::RemotePath,
Shared<
tokio::sync::oneshot::Receiver<
Result<
Option<objectiveai_sdk::functions::RemoteProfile>,
objectiveai_sdk::error::ResponseError,
>,
>,
>,
>,
>,
/// Backs `cached_remote_latest`; maps a path whose commit is optional to a
/// fully resolved `RemotePath`.
remote_latest_cache: Arc<
DashMap<
objectiveai_sdk::RemotePathCommitOptional,
Shared<
tokio::sync::oneshot::Receiver<
Result<
Option<objectiveai_sdk::RemotePath>,
objectiveai_sdk::error::ResponseError,
>,
>,
>,
>,
>,
}
impl<CTXEXT, PC> Clone for Context<CTXEXT, PC> {
fn clone(&self) -> Self {
Self {
ext: self.ext.clone(),
cost_multiplier: self.cost_multiplier,
suppress_output: self.suppress_output,
persistent_cache: self.persistent_cache.clone(),
objectiveai_authorization: self.objectiveai_authorization.clone(),
openrouter_authorization: self.openrouter_authorization.clone(),
github_authorization: self.github_authorization.clone(),
mcp_authorization: self.mcp_authorization.clone(),
viewer_signature: self.viewer_signature.clone(),
viewer_address: self.viewer_address.clone(),
commit_author_name: self.commit_author_name.clone(),
commit_author_email: self.commit_author_email.clone(),
openrouter_authorization_cached: self.openrouter_authorization_cached.clone(),
github_authorization_cached: self.github_authorization_cached.clone(),
mcp_authorization_cached: self.mcp_authorization_cached.clone(),
viewer_signature_cached: self.viewer_signature_cached.clone(),
viewer_address_cached: self.viewer_address_cached.clone(),
commit_author_name_cached: self.commit_author_name_cached.clone(),
commit_author_email_cached: self.commit_author_email_cached.clone(),
cancelled: self.cancelled.clone(),
swarm_cache: self.swarm_cache.clone(),
agent_cache: self.agent_cache.clone(),
function_cache: self.function_cache.clone(),
profile_cache: self.profile_cache.clone(),
remote_latest_cache: self.remote_latest_cache.clone(),
}
}
}
impl<CTXEXT, PC> Context<CTXEXT, PC> {
    /// Reports whether `cancel` has been called on this context or any of
    /// its clones.
    pub fn is_cancelled(&self) -> bool {
        use std::sync::atomic::Ordering;
        self.cancelled.load(Ordering::Relaxed)
    }

    /// Sets the cancellation flag shared by this context and all its clones.
    pub fn cancel(&self) {
        use std::sync::atomic::Ordering;
        self.cancelled.store(true, Ordering::Relaxed);
    }

    /// Builds a context, extracting credentials and identity values from
    /// `headers`.
    ///
    /// For each value a list of candidate header names is tried in order and
    /// the first header that is *present* is used. If that header's value
    /// cannot be read as a string the value is treated as absent; later
    /// candidates are not consulted in that case.
    ///
    /// The MCP authorization header must additionally parse as a JSON object
    /// of string keys to string values, otherwise it is treated as absent.
    ///
    /// All caches and memoization cells start empty and the cancellation
    /// flag starts cleared.
    pub fn new(
        ext: Arc<CTXEXT>,
        persistent_cache: Arc<PC>,
        cost_multiplier: rust_decimal::Decimal,
        suppress_output: bool,
        headers: &axum::http::HeaderMap,
    ) -> Self {
        /// Returns the string value of the first header in `names` that is
        /// present, or `None` if none is present or that value is not a
        /// valid header string.
        fn first_present<'h>(
            headers: &'h axum::http::HeaderMap,
            names: &[&str],
        ) -> Option<&'h str> {
            names
                .iter()
                .find_map(|name| headers.get(*name))
                .and_then(|value| value.to_str().ok())
        }

        /// Same as `first_present`, but copies the value into an `Arc<String>`.
        fn first_present_owned(
            headers: &axum::http::HeaderMap,
            names: &[&str],
        ) -> Option<Arc<String>> {
            first_present(headers, names).map(|text| Arc::new(text.to_owned()))
        }

        let objectiveai_authorization = first_present_owned(
            headers,
            &[
                "X-OBJECTIVEAI-AUTHORIZATION",
                "OBJECTIVEAI-AUTHORIZATION",
                "AUTHORIZATION",
            ],
        );
        let openrouter_authorization = first_present_owned(
            headers,
            &["X-OPENROUTER-AUTHORIZATION", "OPENROUTER-AUTHORIZATION"],
        );
        let github_authorization = first_present_owned(
            headers,
            &["X-GITHUB-AUTHORIZATION", "GITHUB-AUTHORIZATION"],
        );
        let mcp_authorization =
            first_present(headers, &["X-MCP-AUTHORIZATION", "MCP-AUTHORIZATION"])
                .and_then(|raw| serde_json::from_str::<HashMap<String, String>>(raw).ok())
                .map(Arc::new);
        let viewer_signature = first_present_owned(
            headers,
            &[
                "X-VIEWER-SIGNATURE",
                "VIEWER-SIGNATURE",
                "X-OBJECTIVEAI-SIGNATURE",
                "OBJECTIVEAI-SIGNATURE",
            ],
        );
        let viewer_address =
            first_present_owned(headers, &["X-VIEWER-ADDRESS", "VIEWER-ADDRESS"]);
        let commit_author_name =
            first_present_owned(headers, &["X-COMMIT-AUTHOR-NAME", "COMMIT-AUTHOR-NAME"]);
        let commit_author_email =
            first_present_owned(headers, &["X-COMMIT-AUTHOR-EMAIL", "COMMIT-AUTHOR-EMAIL"]);

        Self {
            ext,
            cost_multiplier,
            suppress_output,
            persistent_cache,
            objectiveai_authorization,
            openrouter_authorization,
            github_authorization,
            mcp_authorization,
            viewer_signature,
            viewer_address,
            commit_author_name,
            commit_author_email,
            openrouter_authorization_cached: Arc::new(OnceCell::new()),
            github_authorization_cached: Arc::new(OnceCell::new()),
            mcp_authorization_cached: Arc::new(OnceCell::new()),
            viewer_signature_cached: Arc::new(OnceCell::new()),
            viewer_address_cached: Arc::new(OnceCell::new()),
            commit_author_name_cached: Arc::new(OnceCell::new()),
            commit_author_email_cached: Arc::new(OnceCell::new()),
            cancelled: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            agent_cache: Arc::new(DashMap::new()),
            swarm_cache: Arc::new(DashMap::new()),
            function_cache: Arc::new(DashMap::new()),
            profile_cache: Arc::new(DashMap::new()),
            remote_latest_cache: Arc::new(DashMap::new()),
        }
    }
}
impl<CTXEXT, PC> Context<CTXEXT, PC> {
    /// Returns the ObjectiveAI authorization value captured from the request
    /// headers in `new`, if any.
    ///
    /// Unlike the other credential accessors this one is synchronous and
    /// reads only the header-derived value; it does not call into `ext`.
    pub fn objectiveai_authorization(&self) -> Option<&Arc<String>> {
        Option::as_ref(&self.objectiveai_authorization)
    }
}
/// Resolves `key` through a two-level cache, starting at most one resolver
/// task per key in `cache`.
///
/// Resolution order:
/// 1. `cache` (in-memory): if an entry exists, its shared result is awaited.
/// 2. `persistent_cache`: looked up under `namespace` with the JSON
///    serialization of `key`; a stored value that deserializes as `V` is
///    returned as `Ok(Some(value))`. Lookup or deserialization failures are
///    treated as a miss.
/// 3. `fetch()`: its result is delivered to all waiters. Only an
///    `Ok(Some(value))` result is written back to `persistent_cache` (with
///    the given `permanent` flag); write failures are ignored.
///
/// The in-memory entry keeps whatever the resolver produced, including
/// `Ok(None)` and `Err(_)`, so later calls with the same key on the same map
/// observe the same outcome without re-fetching.
///
/// The resolver runs in a spawned task, so it continues even if the caller
/// that triggered it stops awaiting.
///
/// # Panics
///
/// Panics if `key` cannot be serialized to JSON, or if the resolver task
/// ended without sending a result (for example because `fetch` panicked).
async fn cached_get_or_fetch<K, V, PC, F, Fut>(
    cache: &DashMap<
        K,
        Shared<tokio::sync::oneshot::Receiver<Result<Option<V>, objectiveai_sdk::error::ResponseError>>>,
    >,
    persistent_cache: &Arc<PC>,
    namespace: &'static str,
    key: K,
    permanent: bool,
    fetch: F,
) -> Result<Option<V>, objectiveai_sdk::error::ResponseError>
where
    K: std::hash::Hash + Eq + serde::Serialize + Clone + Send + Sync + 'static,
    V: serde::Serialize + serde::de::DeserializeOwned + Clone + Send + Sync + 'static,
    PC: PersistentCacheClient,
    F: FnOnce() -> Fut + Send + 'static,
    Fut: std::future::Future<Output = Result<Option<V>, objectiveai_sdk::error::ResponseError>> + Send,
{
    // Fast path: an entry already exists. This needs only a shard read lock
    // and skips serializing the key, which is required solely on a miss.
    // The map guard is consumed inside the closure, so no lock is held
    // across the `.await` below.
    let existing = cache.get(&key).map(|slot| slot.value().clone());
    if let Some(shared) = existing {
        return shared
            .await
            .expect("cache resolver task ended without sending a result");
    }

    let persistent_cache = Arc::clone(persistent_cache);
    let persistent_key =
        serde_json::to_string(&key).expect("cache key must be serializable to JSON");

    // `entry` re-checks under the write lock, so if another caller inserted
    // between the fast path and here we join its resolver instead of
    // spawning a second one.
    let shared = cache
        .entry(key)
        .or_insert_with(|| {
            let (tx, rx) = tokio::sync::oneshot::channel();
            tokio::spawn(async move {
                let from_persistent = persistent_cache
                    .get(namespace, &persistent_key)
                    .await
                    .ok()
                    .flatten()
                    .and_then(|s| serde_json::from_str::<V>(&s).ok());
                if let Some(value) = from_persistent {
                    // A send error only means every waiter went away.
                    let _ = tx.send(Ok(Some(value)));
                } else {
                    let result = fetch().await;
                    // Serialize before `result` is moved into the channel.
                    let json_to_persist = match &result {
                        Ok(Some(value)) => serde_json::to_string(value).ok(),
                        _ => None,
                    };
                    // Unblock waiters first; persisting is best-effort and
                    // must not delay them.
                    let _ = tx.send(result);
                    if let Some(json) = json_to_persist {
                        let _ = persistent_cache
                            .set(namespace, &persistent_key, &json, permanent)
                            .await;
                    }
                }
            });
            rx.shared()
        })
        .clone();

    shared
        .await
        .expect("cache resolver task ended without sending a result")
}
impl<CTXEXT, PC: PersistentCacheClient> Context<CTXEXT, PC> {
pub async fn cached_agent<F, Fut>(
&self,
key: objectiveai_sdk::RemotePath,
fetch: F,
) -> Result<Option<objectiveai_sdk::agent::RemoteAgentBaseWithFallbacks>, objectiveai_sdk::error::ResponseError>
where
F: FnOnce() -> Fut + Send + 'static,
Fut: std::future::Future<Output = Result<Option<objectiveai_sdk::agent::RemoteAgentBaseWithFallbacks>, objectiveai_sdk::error::ResponseError>> + Send,
{
cached_get_or_fetch(&self.agent_cache, &self.persistent_cache, "agent", key, true, fetch).await
}
pub async fn cached_swarm<F, Fut>(
&self,
key: objectiveai_sdk::RemotePath,
fetch: F,
) -> Result<Option<objectiveai_sdk::swarm::RemoteSwarmBase>, objectiveai_sdk::error::ResponseError>
where
F: FnOnce() -> Fut + Send + 'static,
Fut: std::future::Future<Output = Result<Option<objectiveai_sdk::swarm::RemoteSwarmBase>, objectiveai_sdk::error::ResponseError>> + Send,
{
cached_get_or_fetch(&self.swarm_cache, &self.persistent_cache, "swarm", key, true, fetch).await
}
pub async fn cached_function<F, Fut>(
&self,
key: objectiveai_sdk::RemotePath,
fetch: F,
) -> Result<Option<objectiveai_sdk::functions::FullRemoteFunction>, objectiveai_sdk::error::ResponseError>
where
F: FnOnce() -> Fut + Send + 'static,
Fut: std::future::Future<Output = Result<Option<objectiveai_sdk::functions::FullRemoteFunction>, objectiveai_sdk::error::ResponseError>> + Send,
{
cached_get_or_fetch(&self.function_cache, &self.persistent_cache, "function", key, true, fetch).await
}
pub async fn cached_profile<F, Fut>(
&self,
key: objectiveai_sdk::RemotePath,
fetch: F,
) -> Result<Option<objectiveai_sdk::functions::RemoteProfile>, objectiveai_sdk::error::ResponseError>
where
F: FnOnce() -> Fut + Send + 'static,
Fut: std::future::Future<Output = Result<Option<objectiveai_sdk::functions::RemoteProfile>, objectiveai_sdk::error::ResponseError>> + Send,
{
cached_get_or_fetch(&self.profile_cache, &self.persistent_cache, "profile", key, true, fetch).await
}
pub async fn cached_remote_latest<F, Fut>(
&self,
key: objectiveai_sdk::RemotePathCommitOptional,
fetch: F,
) -> Result<Option<objectiveai_sdk::RemotePath>, objectiveai_sdk::error::ResponseError>
where
F: FnOnce() -> Fut + Send + 'static,
Fut: std::future::Future<Output = Result<Option<objectiveai_sdk::RemotePath>, objectiveai_sdk::error::ResponseError>> + Send,
{
cached_get_or_fetch(&self.remote_latest_cache, &self.persistent_cache, "remote_latest", key, false, fetch).await
}
}
/// Accessors that combine a value taken from the request headers with the
/// value supplied by the `ext` object.
///
/// Every accessor memoizes its result in a `OnceCell` shared across clones
/// of the context, so the resolution logic runs at most once per context.
///
/// For the single-valued accessors the request header takes precedence and
/// `ext` is consulted only when the header is absent, so no `ext` lookup is
/// performed when its result would be discarded.
impl<CTXEXT: super::ContextExt, PC> Context<CTXEXT, PC> {
    /// Returns the authorization to use for `upstream`.
    ///
    /// Only `Upstream::Openrouter` has an authorization here; every other
    /// upstream yields `None` without touching the memoization cell.
    pub async fn upstream_authorization(
        &self,
        upstream: objectiveai_sdk::agent::Upstream,
    ) -> Option<Arc<String>> {
        if upstream != objectiveai_sdk::agent::Upstream::Openrouter {
            return None;
        }
        self.openrouter_authorization_cached
            .get_or_init(|| async {
                match &self.openrouter_authorization {
                    Some(header_token) => Some(Arc::clone(header_token)),
                    None => self.ext.openrouter_authorization().await,
                }
            })
            .await
            .clone()
    }

    /// Returns the GitHub authorization: the request header value if
    /// present, otherwise whatever `ext` provides.
    pub async fn github_authorization(&self) -> Option<Arc<String>> {
        self.github_authorization_cached
            .get_or_init(|| async {
                match &self.github_authorization {
                    Some(header_token) => Some(Arc::clone(header_token)),
                    None => self.ext.github_authorization().await,
                }
            })
            .await
            .clone()
    }

    /// Returns the MCP authorization map.
    ///
    /// Unlike the single-valued accessors, both sources are always read and
    /// merged: the result starts from the request-header map and entries
    /// from `ext` are inserted on top, so on a key collision the `ext` value
    /// wins.
    ///
    /// NOTE(review): this is the opposite precedence to the other accessors
    /// (where the header wins) -- confirm that it is intentional.
    pub async fn mcp_authorization(&self) -> Option<Arc<HashMap<String, String>>> {
        self.mcp_authorization_cached
            .get_or_init(|| async {
                let byok: Option<Arc<HashMap<String, String>>> =
                    self.ext.mcp_authorization().await;
                match (&self.mcp_authorization, byok) {
                    // No header map: use the ext value as-is (possibly `None`).
                    (None, byok) => byok,
                    // Header map only: share it without copying.
                    (Some(header_map), None) => Some(Arc::clone(header_map)),
                    (Some(header_map), Some(byok_map)) => {
                        let mut merged: HashMap<String, String> = (**header_map).clone();
                        merged.extend(byok_map.iter().map(|(k, v)| (k.clone(), v.clone())));
                        Some(Arc::new(merged))
                    }
                }
            })
            .await
            .clone()
    }

    /// Returns the viewer signature: the request header value if present,
    /// otherwise whatever `ext` provides.
    pub async fn viewer_signature(&self) -> Option<Arc<String>> {
        self.viewer_signature_cached
            .get_or_init(|| async {
                match &self.viewer_signature {
                    Some(header_signature) => Some(Arc::clone(header_signature)),
                    None => self.ext.viewer_signature().await,
                }
            })
            .await
            .clone()
    }

    /// Returns the viewer address: the request header value if present,
    /// otherwise whatever `ext` provides.
    pub async fn viewer_address(&self) -> Option<Arc<String>> {
        self.viewer_address_cached
            .get_or_init(|| async {
                match &self.viewer_address {
                    Some(header_address) => Some(Arc::clone(header_address)),
                    None => self.ext.viewer_address().await,
                }
            })
            .await
            .clone()
    }

    /// Returns the commit author name: the request header value if present,
    /// otherwise whatever `ext` provides.
    pub async fn commit_author_name(&self) -> Option<Arc<String>> {
        self.commit_author_name_cached
            .get_or_init(|| async {
                match &self.commit_author_name {
                    Some(header_name) => Some(Arc::clone(header_name)),
                    None => self.ext.commit_author_name().await,
                }
            })
            .await
            .clone()
    }

    /// Returns the commit author email: the request header value if
    /// present, otherwise whatever `ext` provides.
    pub async fn commit_author_email(&self) -> Option<Arc<String>> {
        self.commit_author_email_cached
            .get_or_init(|| async {
                match &self.commit_author_email {
                    Some(header_email) => Some(Arc::clone(header_email)),
                    None => self.ext.commit_author_email().await,
                }
            })
            .await
            .clone()
    }
}