use crate::{
convert::{ConvertRequestLayer, Filter},
ConvertServiceBuilder,
};
use ic_cdk_management_canister::HttpRequestArgs;
use std::convert::Infallible;
use thiserror::Error;
use tower::ServiceBuilder;
use tower_layer::Stack;
/// Policy deciding how many cycles are charged for an HTTP request and how that charge is
/// collected.
pub trait CyclesChargingPolicy {
/// Error produced by [`CyclesChargingPolicy::charge_cycles`] when charging fails.
type Error;
/// Returns the number of cycles to charge for `request`.
///
/// `request_cycles_cost` is the cost of the request itself; in this file it is the value
/// returned by `ic_cdk_management_canister::cost_http_request`.
fn cycles_to_charge(&self, request: &HttpRequestArgs, request_cycles_cost: u128) -> u128;
/// Charges cycles for `request`.
///
/// The implementations in this file return the number of cycles charged on success.
fn charge_cycles(
&self,
request: &HttpRequestArgs,
request_cycles_cost: u128,
) -> Result<u128, Self::Error>;
}
/// Charging policy under which no cycles are charged for a request.
///
/// Its `CyclesChargingPolicy` implementation always reports a charge of zero and never fails.
///
/// `Debug` is derived so that wrappers which derive `Debug` themselves (such as
/// `CyclesAccounting<ChargeMyself>`) can actually be debug-printed.
#[derive(Default, Clone, Debug)]
pub struct ChargeMyself {}
/// Zero-charge policy: the amount to charge is always `0` and charging cannot fail.
impl CyclesChargingPolicy for ChargeMyself {
    // Nothing is ever collected, so there is no failure mode.
    type Error = Infallible;

    /// Always returns `0`, regardless of the request or its cost.
    fn cycles_to_charge(&self, _request: &HttpRequestArgs, _request_cycles_cost: u128) -> u128 {
        0
    }

    /// Charges nothing and returns `Ok(0)`.
    fn charge_cycles(
        &self,
        request: &HttpRequestArgs,
        request_cycles_cost: u128,
    ) -> Result<u128, Self::Error> {
        // The amount reported as charged is the (constant zero) amount to charge.
        Ok(self.cycles_to_charge(request, request_cycles_cost))
    }
}
/// Charging policy whose charge is computed by a user-supplied function.
///
/// `F` receives the request and the request's cycles cost and returns the number of cycles to
/// charge. The `CyclesChargingPolicy` implementation accepts that amount through
/// `ic_cdk::api::msg_cycles_accept`.
#[derive(Clone)]
pub struct ChargeCaller<F> {
// Function computing the amount to charge from the request and its cycles cost.
cycles_to_charge: F,
}
impl<F> ChargeCaller<F>
where
F: Fn(&HttpRequestArgs, u128) -> u128,
{
pub fn new(cycles_to_charge: F) -> Self {
ChargeCaller { cycles_to_charge }
}
}
/// Charges the amount computed by the wrapped function by accepting cycles from the current
/// message.
impl<F> CyclesChargingPolicy for ChargeCaller<F>
where
    F: Fn(&HttpRequestArgs, u128) -> u128,
{
    type Error = ChargeCallerError;

    /// Returns whatever the wrapped function computes for `request` and `request_cycles_cost`.
    fn cycles_to_charge(&self, request: &HttpRequestArgs, request_cycles_cost: u128) -> u128 {
        let compute_charge = &self.cycles_to_charge;
        compute_charge(request, request_cycles_cost)
    }

    /// Accepts the computed charge from the current message and returns the charged amount.
    ///
    /// A charge of zero succeeds immediately without calling into the IC API.
    ///
    /// # Errors
    ///
    /// Returns [`ChargeCallerError::InsufficientCyclesError`] when
    /// `ic_cdk::api::msg_cycles_available` reports fewer cycles than the charge.
    ///
    /// # Panics
    ///
    /// Panics if `ic_cdk::api::msg_cycles_accept` returns an amount different from the one
    /// requested.
    fn charge_cycles(
        &self,
        request: &HttpRequestArgs,
        request_cycles_cost: u128,
    ) -> Result<u128, Self::Error> {
        let cycles_to_charge = self.cycles_to_charge(request, request_cycles_cost);

        // Nothing to collect: skip the availability check and the accept call entirely.
        if cycles_to_charge == 0 {
            return Ok(cycles_to_charge);
        }

        // Refuse up front rather than accepting a partial amount.
        let attached = ic_cdk::api::msg_cycles_available();
        if attached < cycles_to_charge {
            return Err(ChargeCallerError::InsufficientCyclesError {
                expected: cycles_to_charge,
                received: attached,
            });
        }

        let cycles_received = ic_cdk::api::msg_cycles_accept(cycles_to_charge);
        assert_eq!(
            cycles_received, cycles_to_charge,
            "Expected to receive {cycles_to_charge}, but got {cycles_received}"
        );

        Ok(cycles_to_charge)
    }
}
/// Error returned by [`ChargeCaller`] when cycles cannot be charged.
#[derive(Error, Clone, Debug, PartialEq, Eq)]
pub enum ChargeCallerError {
/// Fewer cycles are available in the current message than the policy wants to charge.
#[error("insufficient cycles (expected {expected:?}, received {received:?})")]
InsufficientCyclesError {
/// Number of cycles the policy wanted to charge.
expected: u128,
/// Number of cycles reported by `ic_cdk::api::msg_cycles_available`.
received: u128,
},
}
/// [`Filter`] over [`HttpRequestArgs`] that charges cycles through a [`CyclesChargingPolicy`]
/// before passing the request on unchanged.
#[derive(Clone, Debug)]
pub struct CyclesAccounting<ChargingPolicy> {
// Policy consulted by `filter` for every request.
charging_policy: ChargingPolicy,
}
impl<ChargingPolicy> CyclesAccounting<ChargingPolicy> {
/// Creates a [`CyclesAccounting`] that uses `charging_policy` to charge for requests.
pub fn new(charging_policy: ChargingPolicy) -> Self {
Self { charging_policy }
}
}
/// Charges for each request through the configured policy before letting it through.
impl<ChargingPolicy> Filter<HttpRequestArgs> for CyclesAccounting<ChargingPolicy>
where
    ChargingPolicy: CyclesChargingPolicy,
{
    type Error = ChargingPolicy::Error;

    /// Computes the cost of `request`, asks the charging policy to charge for it, and returns
    /// the request unchanged on success.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the policy's `charge_cycles`.
    fn filter(&mut self, request: HttpRequestArgs) -> Result<HttpRequestArgs, Self::Error> {
        let request_cost = ic_cdk_management_canister::cost_http_request(&request);
        // The charged amount is not needed here; only success or failure matters.
        self.charging_policy
            .charge_cycles(&request, request_cost)
            .map(|_charged| request)
    }
}
/// Extension trait adding a cycles-accounting step to a [`ServiceBuilder`].
pub trait CyclesAccountingServiceBuilder<L> {
/// Adds a [`ConvertRequestLayer`] wrapping a [`CyclesAccounting`] built from `charging`.
fn cycles_accounting<C>(
self,
charging: C,
) -> ServiceBuilder<Stack<ConvertRequestLayer<CyclesAccounting<C>>, L>>;
}
impl<L> CyclesAccountingServiceBuilder<L> for ServiceBuilder<L> {
    /// Wraps `charging` in a [`CyclesAccounting`] filter and registers it on the builder via
    /// `convert_request`.
    fn cycles_accounting<C>(
        self,
        charging: C,
    ) -> ServiceBuilder<Stack<ConvertRequestLayer<CyclesAccounting<C>>, L>> {
        let accounting = CyclesAccounting::new(charging);
        self.convert_request(accounting)
    }
}