use anda_core::{
BaseContext, BoxError, CacheExpiry, CacheFeatures, CacheStoreFeatures, CancellationToken,
CanisterCaller, HttpFeatures, Json, KeysFeatures, ObjectMeta, Path, PutMode, PutResult,
RequestMeta, StateFeatures, StoreFeatures, ToolInput, ToolOutput, derivation_path_with,
};
use bytes::Bytes;
use candid::{CandidType, Principal, utils::ArgumentEncoder};
use http::Extensions;
use parking_lot::RwLock;
use serde::{Serialize, de::DeserializeOwned};
use std::{
collections::BTreeSet,
future::Future,
sync::Arc,
time::{Duration, Instant},
};
/// Exclusive upper bound on context nesting: creating a child context whose
/// depth would reach this value fails with an error.
const CONTEXT_MAX_DEPTH: u8 = 42;

/// Capacity value handed to `CacheService::new` when a root context is built.
// NOTE(review): presumably the maximum number of cached entries — confirm against `CacheService`.
const CACHE_MAX_CAPACITY: u64 = 1_000_000;
use super::{
RemoteEngines,
cache::CacheService,
web3::{Web3Client, Web3SDK},
};
use crate::store::Store;
/// Execution context that bundles an engine's identity, the current caller,
/// and handles to the services (keys, store, cache, canister and HTTP calls)
/// exposed through the feature traits implemented below.
///
/// Child contexts are created with `child` / `child_with`; they share the
/// cache, store, Web3 SDK, remote-engine registry and state map of their parent.
#[derive(Clone)]
pub struct BaseCtx {
/// Agent name associated with this context; inherited by `child`, replaced by `child_with`.
pub agent: String,
/// Engine identifier, returned by `engine_id()`.
pub(crate) id: Principal,
/// Engine name, returned by `engine_name()` and used as `user` in `self_meta`.
pub(crate) name: String,
/// Principal of the caller; the anonymous principal for a root context built by `new`.
pub(crate) caller: Principal,
/// Namespace path passed to the store, cache and key-derivation helpers.
pub(crate) path: Path,
/// Cancellation token for this context; children receive a child token of it.
pub(crate) cancellation_token: CancellationToken,
/// Instant from which `time_elapsed()` is measured; inherited by `child`, reset by `child_with`.
pub(crate) start_at: Instant,
/// Nesting depth: 0 for a root context, parent depth + 1 for a child.
pub(crate) depth: u8,
/// Web3 backend used for key operations, canister calls and HTTP calls.
pub(crate) web3: Arc<Web3SDK>,
/// Registry of remote engines, used to resolve an endpoint to an engine id.
pub(crate) remote: Arc<RemoteEngines>,
/// Type-keyed state map; the same `Arc` is shared between a context and its children.
pub(crate) state: Arc<RwLock<Extensions>>,
/// Request metadata, returned by `meta()`; inherited by `child`, replaced by `child_with`.
pub(crate) meta: RequestMeta,
/// Cache service shared with all child contexts; accessed under `path`.
cache: Arc<CacheService>,
/// Object store handle; accessed under `path`.
store: Store,
}
impl BaseCtx {
    /// Creates a root context (depth 0) for the engine identified by `id`.
    ///
    /// The caller is the anonymous principal, the path and request metadata
    /// take their default values, and the elapsed-time clock starts now.
    /// A new cache service is built from `CACHE_MAX_CAPACITY` and `names`,
    /// and the shared state map starts out empty.
    #[allow(clippy::too_many_arguments)] // one argument per injected dependency
    pub(crate) fn new(
        id: Principal,
        name: String,
        agent: String,
        cancellation_token: CancellationToken,
        names: BTreeSet<Path>,
        web3: Arc<Web3SDK>,
        store: Store,
        remote: Arc<RemoteEngines>,
    ) -> Self {
        Self {
            id,
            // `name` is owned and not used again, so it is moved instead of cloned.
            name,
            agent,
            caller: Principal::anonymous(),
            path: Path::default(),
            cancellation_token,
            start_at: Instant::now(),
            cache: Arc::new(CacheService::new(CACHE_MAX_CAPACITY, names)),
            store,
            web3,
            depth: 0,
            remote,
            state: Arc::new(RwLock::new(Extensions::default())),
            meta: RequestMeta::default(),
        }
    }

    /// Returns the depth a child of this context would have.
    ///
    /// # Errors
    /// Fails when that depth would reach `CONTEXT_MAX_DEPTH` (or overflow `u8`).
    fn next_depth(&self) -> Result<u8, BoxError> {
        match self.depth.checked_add(1) {
            Some(depth) if depth < CONTEXT_MAX_DEPTH => Ok(depth),
            _ => Err("Context depth limit exceeded".into()),
        }
    }

    /// Creates a child context rooted at `path` that keeps this context's
    /// caller, agent, request metadata and start time.
    ///
    /// The child gets a child cancellation token and shares the cache, store,
    /// Web3 SDK, remote-engine registry and state map with its parent.
    ///
    /// # Errors
    /// Fails if `path` cannot be parsed or the depth limit would be exceeded.
    pub(crate) fn child(&self, path: String) -> Result<Self, BoxError> {
        // Parse first so an invalid path is still reported ahead of a depth error.
        let path = Path::parse(path)?;
        // Validate the depth before allocating anything for the child.
        let depth = self.next_depth()?;
        Ok(Self {
            id: self.id,
            name: self.name.clone(),
            agent: self.agent.clone(),
            caller: self.caller,
            path,
            cancellation_token: self.cancellation_token.child_token(),
            start_at: self.start_at,
            cache: self.cache.clone(),
            store: self.store.clone(),
            web3: self.web3.clone(),
            depth,
            remote: self.remote.clone(),
            state: self.state.clone(),
            meta: self.meta.clone(),
        })
    }

    /// Creates a child context rooted at `path` with a new `caller`, `agent`
    /// and request `meta`; its elapsed-time clock restarts now.
    ///
    /// Everything else (cache, store, Web3 SDK, remote-engine registry, state
    /// map) is shared with the parent, and the child gets a child cancellation
    /// token.
    ///
    /// # Errors
    /// Fails if `path` cannot be parsed or the depth limit would be exceeded.
    pub(crate) fn child_with(
        &self,
        caller: Principal,
        agent: String,
        path: String,
        meta: RequestMeta,
    ) -> Result<Self, BoxError> {
        // Parse first so an invalid path is still reported ahead of a depth error.
        let path = Path::parse(path)?;
        // Validate the depth before allocating anything for the child.
        let depth = self.next_depth()?;
        Ok(Self {
            id: self.id,
            name: self.name.clone(),
            agent,
            caller,
            path,
            cancellation_token: self.cancellation_token.child_token(),
            start_at: Instant::now(),
            cache: self.cache.clone(),
            store: self.store.clone(),
            web3: self.web3.clone(),
            depth,
            remote: self.remote.clone(),
            state: self.state.clone(),
            meta,
        })
    }

    /// Builds the request metadata used when this engine calls `target`:
    /// `engine` is set to `target`, `user` to this engine's name, and all
    /// remaining fields take their default values.
    pub(crate) fn self_meta(&self, target: Principal) -> RequestMeta {
        RequestMeta {
            engine: Some(target),
            user: Some(self.name.clone()),
            ..Default::default()
        }
    }

    /// Returns a clone of the value of type `T` held in the shared state map,
    /// or `None` if no value of that type has been stored.
    pub fn get_state<T>(&self) -> Option<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        self.state.read().get::<T>().cloned()
    }

    /// Stores `v` in the shared state map, returning the value of the same
    /// type that was stored before, if any.
    ///
    /// The map is shared with every context derived from the same root, so
    /// the new value is visible to all of them.
    pub fn set_state<T>(&self, v: T) -> Option<T>
    where
        T: Clone + Send + Sync + 'static,
    {
        self.state.write().insert(v)
    }
}
impl BaseContext for BaseCtx {
    /// Calls the `tool_call` RPC method on the remote engine at `endpoint`.
    ///
    /// The endpoint is resolved to an engine id through the remote-engine
    /// registry, and `args.meta` is overwritten with this engine's own
    /// metadata (see `self_meta`) before the signed RPC is sent.
    ///
    /// # Errors
    /// Fails if the endpoint is not registered, or if the signed RPC fails.
    async fn remote_tool_call(
        &self,
        endpoint: &str,
        mut args: ToolInput<Json>,
    ) -> Result<ToolOutput<Json>, BoxError> {
        let Some(target) = self.remote.get_id_by_endpoint(endpoint) else {
            return Err(format!("remote engine endpoint {endpoint} not found").into());
        };

        args.meta = Some(self.self_meta(target));

        // The RPC arguments are sent as a one-element tuple holding `args`.
        let rpc_args = &(&args,);
        self.https_signed_rpc(endpoint, "tool_call", rpc_args).await
    }
}
// No overrides here: every `CacheStoreFeatures` item used for `BaseCtx` comes
// from the trait's provided default implementations.
impl CacheStoreFeatures for BaseCtx {}
impl StateFeatures for BaseCtx {
fn engine_id(&self) -> &Principal {
&self.id
}
fn engine_name(&self) -> &str {
&self.name
}
fn caller(&self) -> &Principal {
&self.caller
}
fn meta(&self) -> &RequestMeta {
&self.meta
}
fn cancellation_token(&self) -> CancellationToken {
self.cancellation_token.clone()
}
fn time_elapsed(&self) -> Duration {
self.start_at.elapsed()
}
}
// Every method below follows the same shape: combine this context's `path`
// with the caller-supplied `derivation_path` via `derivation_path_with`, then
// forward the call to the same-named method of whichever Web3 backend is
// configured (`Web3SDK::Tee` or `Web3SDK::Web3`).
impl KeysFeatures for BaseCtx {
    /// Forwards to the backend's `a256gcm_key` for the context-scoped derivation path.
    async fn a256gcm_key(&self, derivation_path: Vec<Vec<u8>>) -> Result<[u8; 32], BoxError> {
        let scoped = derivation_path_with(&self.path, derivation_path);
        match self.web3.as_ref() {
            Web3SDK::Tee(tee) => tee.a256gcm_key(scoped).await,
            Web3SDK::Web3(Web3Client { client }) => client.a256gcm_key(scoped).await,
        }
    }

    /// Forwards to the backend's `ed25519_sign_message` for the context-scoped
    /// derivation path, returning the 64-byte result.
    async fn ed25519_sign_message(
        &self,
        derivation_path: Vec<Vec<u8>>,
        message: &[u8],
    ) -> Result<[u8; 64], BoxError> {
        let scoped = derivation_path_with(&self.path, derivation_path);
        match self.web3.as_ref() {
            Web3SDK::Tee(tee) => tee.ed25519_sign_message(scoped, message).await,
            Web3SDK::Web3(Web3Client { client }) => {
                client.ed25519_sign_message(scoped, message).await
            }
        }
    }

    /// Forwards to the backend's `ed25519_verify` for the context-scoped
    /// derivation path; `Ok(())` is whatever the backend reports as success.
    async fn ed25519_verify(
        &self,
        derivation_path: Vec<Vec<u8>>,
        message: &[u8],
        signature: &[u8],
    ) -> Result<(), BoxError> {
        let scoped = derivation_path_with(&self.path, derivation_path);
        match self.web3.as_ref() {
            Web3SDK::Tee(tee) => tee.ed25519_verify(scoped, message, signature).await,
            Web3SDK::Web3(Web3Client { client }) => {
                client.ed25519_verify(scoped, message, signature).await
            }
        }
    }

    /// Forwards to the backend's `ed25519_public_key` for the context-scoped
    /// derivation path, returning the 32-byte result.
    async fn ed25519_public_key(
        &self,
        derivation_path: Vec<Vec<u8>>,
    ) -> Result<[u8; 32], BoxError> {
        let scoped = derivation_path_with(&self.path, derivation_path);
        match self.web3.as_ref() {
            Web3SDK::Tee(tee) => tee.ed25519_public_key(scoped).await,
            Web3SDK::Web3(Web3Client { client }) => client.ed25519_public_key(scoped).await,
        }
    }

    /// Forwards to the backend's `secp256k1_sign_message_bip340` for the
    /// context-scoped derivation path, returning the 64-byte result.
    async fn secp256k1_sign_message_bip340(
        &self,
        derivation_path: Vec<Vec<u8>>,
        message: &[u8],
    ) -> Result<[u8; 64], BoxError> {
        let scoped = derivation_path_with(&self.path, derivation_path);
        match self.web3.as_ref() {
            Web3SDK::Tee(tee) => tee.secp256k1_sign_message_bip340(scoped, message).await,
            Web3SDK::Web3(Web3Client { client }) => {
                client.secp256k1_sign_message_bip340(scoped, message).await
            }
        }
    }

    /// Forwards to the backend's `secp256k1_verify_bip340` for the
    /// context-scoped derivation path.
    async fn secp256k1_verify_bip340(
        &self,
        derivation_path: Vec<Vec<u8>>,
        message: &[u8],
        signature: &[u8],
    ) -> Result<(), BoxError> {
        let scoped = derivation_path_with(&self.path, derivation_path);
        match self.web3.as_ref() {
            Web3SDK::Tee(tee) => {
                tee.secp256k1_verify_bip340(scoped, message, signature)
                    .await
            }
            Web3SDK::Web3(Web3Client { client }) => {
                client
                    .secp256k1_verify_bip340(scoped, message, signature)
                    .await
            }
        }
    }

    /// Forwards to the backend's `secp256k1_sign_message_ecdsa` for the
    /// context-scoped derivation path, returning the 64-byte result.
    async fn secp256k1_sign_message_ecdsa(
        &self,
        derivation_path: Vec<Vec<u8>>,
        message: &[u8],
    ) -> Result<[u8; 64], BoxError> {
        let scoped = derivation_path_with(&self.path, derivation_path);
        match self.web3.as_ref() {
            Web3SDK::Tee(tee) => tee.secp256k1_sign_message_ecdsa(scoped, message).await,
            Web3SDK::Web3(Web3Client { client }) => {
                client.secp256k1_sign_message_ecdsa(scoped, message).await
            }
        }
    }

    /// Forwards to the backend's `secp256k1_sign_digest_ecdsa` for the
    /// context-scoped derivation path, passing `message_hash` through unchanged.
    async fn secp256k1_sign_digest_ecdsa(
        &self,
        derivation_path: Vec<Vec<u8>>,
        message_hash: &[u8],
    ) -> Result<[u8; 64], BoxError> {
        let scoped = derivation_path_with(&self.path, derivation_path);
        match self.web3.as_ref() {
            Web3SDK::Tee(tee) => {
                tee.secp256k1_sign_digest_ecdsa(scoped, message_hash)
                    .await
            }
            Web3SDK::Web3(Web3Client { client }) => {
                client
                    .secp256k1_sign_digest_ecdsa(scoped, message_hash)
                    .await
            }
        }
    }

    /// Forwards to the backend's `secp256k1_verify_ecdsa` for the
    /// context-scoped derivation path.
    async fn secp256k1_verify_ecdsa(
        &self,
        derivation_path: Vec<Vec<u8>>,
        message_hash: &[u8],
        signature: &[u8],
    ) -> Result<(), BoxError> {
        let scoped = derivation_path_with(&self.path, derivation_path);
        match self.web3.as_ref() {
            Web3SDK::Tee(tee) => {
                tee.secp256k1_verify_ecdsa(scoped, message_hash, signature)
                    .await
            }
            Web3SDK::Web3(Web3Client { client }) => {
                client
                    .secp256k1_verify_ecdsa(scoped, message_hash, signature)
                    .await
            }
        }
    }

    /// Forwards to the backend's `secp256k1_public_key` for the context-scoped
    /// derivation path, returning the 33-byte result.
    async fn secp256k1_public_key(
        &self,
        derivation_path: Vec<Vec<u8>>,
    ) -> Result<[u8; 33], BoxError> {
        let scoped = derivation_path_with(&self.path, derivation_path);
        match self.web3.as_ref() {
            Web3SDK::Tee(tee) => tee.secp256k1_public_key(scoped).await,
            Web3SDK::Web3(Web3Client { client }) => client.secp256k1_public_key(scoped).await,
        }
    }
}
// Each method forwards to the same-named `Store` method, passing this
// context's `path` as the first argument ahead of the caller's arguments.
impl StoreFeatures for BaseCtx {
    /// Reads the object at `path` within this context's namespace.
    async fn store_get(&self, path: &Path) -> Result<(bytes::Bytes, ObjectMeta), BoxError> {
        let namespace = &self.path;
        self.store.store_get(namespace, path).await
    }

    /// Lists object metadata within this context's namespace, forwarding
    /// `prefix` and `offset` to the store unchanged.
    async fn store_list(
        &self,
        prefix: Option<&Path>,
        offset: &Path,
    ) -> Result<Vec<ObjectMeta>, BoxError> {
        let namespace = &self.path;
        self.store.store_list(namespace, prefix, offset).await
    }

    /// Writes `value` at `path` within this context's namespace using `mode`.
    async fn store_put(
        &self,
        path: &Path,
        mode: PutMode,
        value: bytes::Bytes,
    ) -> Result<PutResult, BoxError> {
        let namespace = &self.path;
        self.store.store_put(namespace, path, mode, value).await
    }

    /// Forwards to the store's `store_rename_if_not_exists` for `from` and
    /// `to` within this context's namespace.
    async fn store_rename_if_not_exists(&self, from: &Path, to: &Path) -> Result<(), BoxError> {
        let namespace = &self.path;
        self.store
            .store_rename_if_not_exists(namespace, from, to)
            .await
    }

    /// Deletes the object at `path` within this context's namespace.
    async fn store_delete(&self, path: &Path) -> Result<(), BoxError> {
        let namespace = &self.path;
        self.store.store_delete(namespace, path).await
    }
}
// Each method forwards to the shared `CacheService`, passing this context's
// `path` as the first argument ahead of the caller's arguments.
impl CacheFeatures for BaseCtx {
    /// Reports whether the cache holds `key` under this context's path.
    fn cache_contains(&self, key: &str) -> bool {
        let namespace = &self.path;
        self.cache.contains(namespace, key)
    }

    /// Fetches the value cached under `key`, deserialized as `T`.
    async fn cache_get<T>(&self, key: &str) -> Result<T, BoxError>
    where
        T: DeserializeOwned,
    {
        let namespace = &self.path;
        self.cache.get(namespace, key).await
    }

    /// Forwards to the cache service's `get_with`, handing it the `init`
    /// future that produces a value and optional expiry.
    async fn cache_get_with<T, F>(&self, key: &str, init: F) -> Result<T, BoxError>
    where
        T: Sized + DeserializeOwned + Serialize + Send,
        F: Future<Output = Result<(T, Option<CacheExpiry>), BoxError>> + Send + 'static,
    {
        let namespace = &self.path;
        self.cache.get_with(namespace, key, init).await
    }

    /// Stores `val` (value plus optional expiry) under `key`.
    async fn cache_set<T>(&self, key: &str, val: (T, Option<CacheExpiry>))
    where
        T: Sized + Serialize + Send,
    {
        let namespace = &self.path;
        self.cache.set(namespace, key, val).await
    }

    /// Forwards to the cache service's `set_if_not_exists` and returns its
    /// boolean result.
    async fn cache_set_if_not_exists<T>(&self, key: &str, val: (T, Option<CacheExpiry>)) -> bool
    where
        T: Sized + Serialize + Send,
    {
        let namespace = &self.path;
        self.cache.set_if_not_exists(namespace, key, val).await
    }

    /// Forwards to the cache service's `delete` and returns its boolean result.
    async fn cache_delete(&self, key: &str) -> bool {
        let namespace = &self.path;
        self.cache.delete(namespace, key).await
    }

    /// Returns the cache service's raw iterator for this context's path,
    /// yielding keys with their serialized bytes and optional expiry.
    fn cache_raw_iter(
        &self,
    ) -> impl Iterator<Item = (Arc<String>, Arc<(Bytes, Option<CacheExpiry>)>)> {
        self.cache.iter(&self.path)
    }
}
// Both methods hand their arguments unchanged to the Web3 backend.
impl CanisterCaller for BaseCtx {
    /// Forwards a query call for `method` on `canister` to the Web3 backend.
    async fn canister_query<
        In: ArgumentEncoder + Send,
        Out: CandidType + for<'a> candid::Deserialize<'a>,
    >(
        &self,
        canister: &Principal,
        method: &str,
        args: In,
    ) -> Result<Out, BoxError> {
        let sdk = self.web3.as_ref();
        sdk.canister_query(canister, method, args).await
    }

    /// Forwards an update call for `method` on `canister` to the Web3 backend.
    async fn canister_update<
        In: ArgumentEncoder + Send,
        Out: CandidType + for<'a> candid::Deserialize<'a>,
    >(
        &self,
        canister: &Principal,
        method: &str,
        args: In,
    ) -> Result<Out, BoxError> {
        let sdk = self.web3.as_ref();
        sdk.canister_update(canister, method, args).await
    }
}
// All methods hand their arguments unchanged to the Web3 backend.
impl HttpFeatures for BaseCtx {
    /// Forwards an HTTP request (`url`, `method`, optional `headers` and
    /// `body`) to the Web3 backend's `https_call`.
    async fn https_call(
        &self,
        url: &str,
        method: http::Method,
        headers: Option<http::HeaderMap>,
        body: Option<Vec<u8>>,
    ) -> Result<reqwest::Response, BoxError> {
        let sdk = self.web3.as_ref();
        sdk.https_call(url, method, headers, body).await
    }

    /// Forwards an HTTP request together with a 32-byte `message_digest` to
    /// the Web3 backend's `https_signed_call`.
    async fn https_signed_call(
        &self,
        url: &str,
        method: http::Method,
        message_digest: [u8; 32],
        headers: Option<http::HeaderMap>,
        body: Option<Vec<u8>>,
    ) -> Result<reqwest::Response, BoxError> {
        let sdk = self.web3.as_ref();
        sdk.https_signed_call(url, method, message_digest, headers, body)
            .await
    }

    /// Forwards an RPC call for `method` at `endpoint` with serializable
    /// `args` to the Web3 backend's `https_signed_rpc`, returning the
    /// response deserialized as `T`.
    async fn https_signed_rpc<T>(
        &self,
        endpoint: &str,
        method: &str,
        args: impl Serialize + Send,
    ) -> Result<T, BoxError>
    where
        T: DeserializeOwned,
    {
        let sdk = self.web3.as_ref();
        sdk.https_signed_rpc(endpoint, method, args).await
    }
}