use crate::api::call::{call_with_payment128, CallResult};
use candid::{
parser::types::FuncMode,
types::{Function, Serializer, Type},
CandidType, Principal,
};
use core::hash::Hash;
use serde::{Deserialize, Serialize};
/// Candid `func` reference to a transform method.
///
/// Its Candid type is a query function taking a single [`TransformArgs`] and
/// returning a single [`HttpResponse`]. On the wire it is encoded as a function
/// reference: the owning principal's bytes plus the method name.
#[derive(Deserialize, Debug, PartialEq, Eq, Clone)]
pub struct TransformFunc(pub candid::Func);

impl CandidType for TransformFunc {
    fn _ty() -> Type {
        // The signature is `(TransformArgs) -> (HttpResponse) query`.
        let signature = Function {
            modes: vec![FuncMode::Query],
            args: vec![TransformArgs::ty()],
            rets: vec![HttpResponse::ty()],
        };
        Type::Func(signature)
    }

    fn idl_serialize<S: Serializer>(&self, serializer: S) -> Result<(), S::Error> {
        let candid::Func {
            principal, method, ..
        } = &self.0;
        serializer.serialize_function(principal.as_slice(), method)
    }
}
/// Argument passed to a transform function (see the `TransformFunc` signature).
#[derive(CandidType, Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct TransformArgs {
    /// The HTTP response to be transformed.
    pub response: HttpResponse,
    /// Opaque context bytes, serialized as a byte string via `serde_bytes`.
    // NOTE(review): presumably the `context` set in `TransformContext` for the
    // originating request — confirm against the IC interface spec.
    #[serde(with = "serde_bytes")]
    pub context: Vec<u8>,
}
/// A transform function reference paired with opaque context bytes, attached to
/// an outbound request via `CanisterHttpRequestArgument::transform`.
#[derive(CandidType, Clone, Debug, Deserialize, PartialEq, Eq)]
pub struct TransformContext {
    /// The query method used to transform the response.
    pub function: TransformFunc,
    /// Caller-defined bytes, serialized as a byte string via `serde_bytes`.
    #[serde(with = "serde_bytes")]
    pub context: Vec<u8>,
}
impl TransformContext {
    /// Builds a context that references `func` as a method of this canister.
    ///
    /// The principal is this canister's own id (`crate::id()`), and the method
    /// name is taken from the last `:`-separated segment of `func`'s type name.
    /// This means `func` should be a named function that is also exported as a
    /// query method under that same name. A closure will not resolve to a usable
    /// method name.
    pub fn new<T>(func: T, context: Vec<u8>) -> Self
    where
        T: Fn(TransformArgs) -> HttpResponse,
    {
        let principal = crate::id();
        let method = get_function_name(func).to_string();
        let function = TransformFunc(candid::Func { principal, method });
        Self { function, context }
    }
}
/// Returns the final `:`-separated segment of `F`'s type name.
///
/// For a fn item such as `my_crate::module::func`, this yields `"func"`. A type
/// name containing no `:` is returned whole.
///
/// NOTE: `std::any::type_name` does not guarantee a stable output format.
/// Closures and generic functions may therefore produce unexpected segments.
fn get_function_name<F>(_: F) -> &'static str {
    let type_path: &'static str = std::any::type_name::<F>();
    // `rsplit` always yields at least one item, so the fallback is never used.
    type_path.rsplit(':').next().unwrap_or(type_path)
}
/// A single HTTP header as a name/value pair.
///
/// The derived ordering compares `name` first, then `value`.
#[derive(
    CandidType, Serialize, Deserialize, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Default,
)]
pub struct HttpHeader {
    /// Header name.
    pub name: String,
    /// Header value.
    pub value: String,
}
/// HTTP method of an outbound request.
///
/// `#[serde(rename)]` maps the variants to the lowercase names `get`, `post`
/// and `head` when serialized.
// NOTE(review): the upper-case variant names deviate from Rust's UpperCamelCase
// convention (clippy `upper_case_acronyms`) but are public API, so they stay.
#[derive(
    CandidType, Serialize, Deserialize, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy,
)]
pub enum HttpMethod {
    /// Serialized as `get`.
    #[serde(rename = "get")]
    GET,
    /// Serialized as `post`.
    #[serde(rename = "post")]
    POST,
    /// Serialized as `head`.
    #[serde(rename = "head")]
    HEAD,
}
/// Argument to the management canister's `http_request` method, as sent by
/// `http_request` and `http_request_with_cycles`.
#[derive(CandidType, Deserialize, Debug, PartialEq, Eq, Clone)]
pub struct CanisterHttpRequestArgument {
    /// Target URL.
    pub url: String,
    /// Upper bound on the response size in bytes.
    ///
    /// When `None`, the cycle estimate used by `http_request` assumes 2 MiB.
    pub max_response_bytes: Option<u64>,
    /// HTTP method to use.
    pub method: HttpMethod,
    /// Request headers.
    pub headers: Vec<HttpHeader>,
    /// Optional request body bytes.
    pub body: Option<Vec<u8>>,
    /// Optional query function (with context) that maps the raw response,
    /// wrapped in `TransformArgs`, to an `HttpResponse`.
    pub transform: Option<TransformContext>,
}
/// HTTP response returned by the management canister's `http_request` method
/// and produced by transform functions.
///
/// The derived ordering compares `status`, then `headers`, then `body`.
#[derive(
    CandidType, Serialize, Deserialize, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Default,
)]
pub struct HttpResponse {
    /// HTTP status code, stored as an arbitrary-precision `candid::Nat`.
    pub status: candid::Nat,
    /// Response headers.
    pub headers: Vec<HttpHeader>,
    /// Raw response body bytes.
    pub body: Vec<u8>,
}
pub async fn http_request(arg: CanisterHttpRequestArgument) -> CallResult<(HttpResponse,)> {
let cycles = http_request_required_cycles(&arg);
call_with_payment128(
Principal::management_canister(),
"http_request",
(arg,),
cycles,
)
.await
}
/// Performs an HTTPS outcall via the management canister's `http_request`
/// method, attaching exactly `cycles` cycles.
///
/// # Errors
/// Returns the error from the underlying inter-canister call as a [`CallResult`].
pub async fn http_request_with_cycles(
    arg: CanisterHttpRequestArgument,
    cycles: u128,
) -> CallResult<(HttpResponse,)> {
    let management = Principal::management_canister();
    let payload = (arg,);
    call_with_payment128(management, "http_request", payload, cycles).await
}
/// Estimates the cycles to attach to an `http_request` call for `arg`.
///
/// The result is computed as:
/// `BASE_FEE + PER_BYTE_FEE * (request_bytes + max_response_bytes)`.
/// Here `request_bytes` is the Candid-encoded size of `arg` plus the length of
/// the method name. When `arg.max_response_bytes` is `None`, a 2 MiB response
/// limit is assumed.
///
/// # Panics
/// Panics if `arg` cannot be Candid-encoded. This is not expected for a value
/// of this type.
fn http_request_required_cycles(arg: &CanisterHttpRequestArgument) -> u128 {
    /// Flat fee charged per outcall.
    const BASE_FEE: u128 = 400_000_000;
    /// Fee per byte of request size and maximum response size.
    const PER_BYTE_FEE: u128 = 100_000;
    /// Response limit assumed when `max_response_bytes` is unset (2 MiB).
    const DEFAULT_MAX_RESPONSE_BYTES: u128 = 2 * 1024 * 1024;
    /// Method name whose length (12 bytes) is counted toward the request size.
    const METHOD_NAME: &str = "http_request";

    // `u64 -> u128` is lossless, so `From` is used instead of an `as` cast.
    let max_response_bytes = arg
        .max_response_bytes
        .map_or(DEFAULT_MAX_RESPONSE_BYTES, u128::from);
    let arg_raw = candid::utils::encode_args((arg,)).expect("Failed to encode arguments.");
    // Widen each length separately so the sum cannot overflow `usize`.
    let request_bytes = arg_raw.len() as u128 + METHOD_NAME.len() as u128;
    BASE_FEE + PER_BYTE_FEE * (request_bytes + max_response_bytes)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bodiless GET request to example.com with the given response cap.
    fn example_get(max_response_bytes: Option<u64>) -> CanisterHttpRequestArgument {
        CanisterHttpRequestArgument {
            url: "https://example.com".to_string(),
            max_response_bytes,
            method: HttpMethod::GET,
            headers: vec![],
            body: None,
            transform: None,
        }
    }

    #[test]
    fn required_cycles_some_max() {
        let request = example_get(Some(3000));
        assert_eq!(http_request_required_cycles(&request), 718_500_000u128);
    }

    #[test]
    fn required_cycles_none_max() {
        let request = example_get(None);
        assert_eq!(http_request_required_cycles(&request), 210_132_900_000u128);
    }

    #[test]
    fn get_function_name_work() {
        fn func() {}
        let name = get_function_name(func);
        assert_eq!(name, "func");
    }
}