Skip to main content

canic_host/candid_endpoints/
mod.rs

1use candid::{
2    TypeEnv,
3    types::{FuncMode, Function, Label, Type, TypeInner},
4};
5use candid_parser::utils::CandidSource;
6use serde::Serialize;
7use thiserror::Error as ThisError;
8
9///
10/// CandidEndpointError
11///
12
13#[derive(Debug, ThisError)]
14pub enum CandidEndpointError {
15    #[error("canister interface did not contain a service block")]
16    MissingService,
17
18    #[error("failed to parse Candid interface: {0}")]
19    InvalidCandid(String),
20}
21
22///
23/// EndpointEntry
24///
25
26#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
27pub struct EndpointEntry {
28    pub name: String,
29    pub candid: String,
30    pub modes: Vec<EndpointMode>,
31    pub arguments: Vec<EndpointType>,
32    pub returns: Vec<EndpointType>,
33}
34
35///
36/// EndpointMode
37///
38
39#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
40#[serde(rename_all = "snake_case")]
41pub enum EndpointMode {
42    Query,
43    CompositeQuery,
44    Oneway,
45}
46
47///
48/// EndpointCardinality
49///
50
51#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
52#[serde(rename_all = "snake_case")]
53pub enum EndpointCardinality {
54    Single,
55    Optional,
56    Many,
57}
58
59///
60/// EndpointType
61///
62
63#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
64#[serde(tag = "kind", rename_all = "snake_case")]
65pub enum EndpointType {
66    Primitive {
67        candid: String,
68        cardinality: EndpointCardinality,
69        name: String,
70    },
71    Named {
72        candid: String,
73        cardinality: EndpointCardinality,
74        name: String,
75        #[serde(skip_serializing_if = "Option::is_none")]
76        resolved: Option<Box<Self>>,
77    },
78    Optional {
79        candid: String,
80        cardinality: EndpointCardinality,
81        inner: Box<Self>,
82    },
83    Vector {
84        candid: String,
85        cardinality: EndpointCardinality,
86        inner: Box<Self>,
87    },
88    Record {
89        candid: String,
90        cardinality: EndpointCardinality,
91        fields: Vec<EndpointField>,
92    },
93    Variant {
94        candid: String,
95        cardinality: EndpointCardinality,
96        cases: Vec<EndpointField>,
97    },
98    Function {
99        candid: String,
100        cardinality: EndpointCardinality,
101        modes: Vec<EndpointMode>,
102        arguments: Vec<Self>,
103        returns: Vec<Self>,
104    },
105    Service {
106        candid: String,
107        cardinality: EndpointCardinality,
108        methods: Vec<EndpointServiceMethod>,
109    },
110    Class {
111        candid: String,
112        cardinality: EndpointCardinality,
113        initializers: Vec<Self>,
114        service: Box<Self>,
115    },
116}
117
118///
119/// EndpointField
120///
121
122#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
123pub struct EndpointField {
124    pub label: String,
125    pub id: u32,
126    pub ty: EndpointType,
127}
128
129///
130/// EndpointServiceMethod
131///
132
133#[derive(Clone, Debug, Eq, PartialEq, Serialize)]
134pub struct EndpointServiceMethod {
135    pub name: String,
136    pub ty: EndpointType,
137}
138
139/// Parse a Candid service interface into structured endpoint descriptions.
140pub fn parse_candid_service_endpoints(
141    candid: &str,
142) -> Result<Vec<EndpointEntry>, CandidEndpointError> {
143    let (env, actor) = CandidSource::Text(candid)
144        .load()
145        .map_err(|err| CandidEndpointError::InvalidCandid(err.to_string()))?;
146    let Some(actor) = actor else {
147        return Err(CandidEndpointError::MissingService);
148    };
149    let service = env
150        .as_service(&actor)
151        .map_err(|_| CandidEndpointError::MissingService)?;
152    service
153        .iter()
154        .map(|(name, ty)| endpoint_entry(&env, name, ty))
155        .collect()
156}
157
158fn endpoint_entry(
159    env: &TypeEnv,
160    name: &str,
161    ty: &Type,
162) -> Result<EndpointEntry, CandidEndpointError> {
163    let function = env
164        .as_func(ty)
165        .map_err(|err| CandidEndpointError::InvalidCandid(err.to_string()))?;
166    Ok(EndpointEntry {
167        name: name.to_string(),
168        candid: format!("{} : {};", render_candid_method_name(name), function),
169        modes: endpoint_modes(function),
170        arguments: endpoint_types(env, &function.args),
171        returns: endpoint_types(env, &function.rets),
172    })
173}
174
175fn endpoint_types(env: &TypeEnv, types: &[Type]) -> Vec<EndpointType> {
176    types
177        .iter()
178        .map(|ty| endpoint_type(env, ty, &mut Vec::new()))
179        .collect()
180}
181
182fn endpoint_type(env: &TypeEnv, ty: &Type, named_stack: &mut Vec<String>) -> EndpointType {
183    match ty.as_ref() {
184        TypeInner::Null => primitive_type(ty, "null"),
185        TypeInner::Bool => primitive_type(ty, "bool"),
186        TypeInner::Nat => primitive_type(ty, "nat"),
187        TypeInner::Int => primitive_type(ty, "int"),
188        TypeInner::Nat8 => primitive_type(ty, "nat8"),
189        TypeInner::Nat16 => primitive_type(ty, "nat16"),
190        TypeInner::Nat32 => primitive_type(ty, "nat32"),
191        TypeInner::Nat64 => primitive_type(ty, "nat64"),
192        TypeInner::Int8 => primitive_type(ty, "int8"),
193        TypeInner::Int16 => primitive_type(ty, "int16"),
194        TypeInner::Int32 => primitive_type(ty, "int32"),
195        TypeInner::Int64 => primitive_type(ty, "int64"),
196        TypeInner::Float32 => primitive_type(ty, "float32"),
197        TypeInner::Float64 => primitive_type(ty, "float64"),
198        TypeInner::Text => primitive_type(ty, "text"),
199        TypeInner::Reserved => primitive_type(ty, "reserved"),
200        TypeInner::Empty => primitive_type(ty, "empty"),
201        TypeInner::Principal => primitive_type(ty, "principal"),
202        TypeInner::Future => primitive_type(ty, "future"),
203        TypeInner::Unknown => primitive_type(ty, "unknown"),
204        TypeInner::Knot(id) => EndpointType::Named {
205            candid: ty.to_string(),
206            cardinality: EndpointCardinality::Single,
207            name: id.to_string(),
208            resolved: None,
209        },
210        TypeInner::Var(name) => named_type(env, ty, name, named_stack),
211        TypeInner::Opt(inner) => EndpointType::Optional {
212            candid: ty.to_string(),
213            cardinality: EndpointCardinality::Optional,
214            inner: Box::new(endpoint_type(env, inner, named_stack)),
215        },
216        TypeInner::Vec(inner) => EndpointType::Vector {
217            candid: ty.to_string(),
218            cardinality: EndpointCardinality::Many,
219            inner: Box::new(endpoint_type(env, inner, named_stack)),
220        },
221        TypeInner::Record(fields) => EndpointType::Record {
222            candid: ty.to_string(),
223            cardinality: EndpointCardinality::Single,
224            fields: endpoint_fields(env, fields, named_stack),
225        },
226        TypeInner::Variant(fields) => EndpointType::Variant {
227            candid: ty.to_string(),
228            cardinality: EndpointCardinality::Single,
229            cases: endpoint_fields(env, fields, named_stack),
230        },
231        TypeInner::Func(function) => EndpointType::Function {
232            candid: ty.to_string(),
233            cardinality: EndpointCardinality::Single,
234            modes: endpoint_modes(function),
235            arguments: endpoint_types(env, &function.args),
236            returns: endpoint_types(env, &function.rets),
237        },
238        TypeInner::Service(methods) => EndpointType::Service {
239            candid: ty.to_string(),
240            cardinality: EndpointCardinality::Single,
241            methods: methods
242                .iter()
243                .map(|(name, ty)| EndpointServiceMethod {
244                    name: name.clone(),
245                    ty: endpoint_type(env, ty, named_stack),
246                })
247                .collect(),
248        },
249        TypeInner::Class(initializers, service) => EndpointType::Class {
250            candid: ty.to_string(),
251            cardinality: EndpointCardinality::Single,
252            initializers: endpoint_types(env, initializers),
253            service: Box::new(endpoint_type(env, service, named_stack)),
254        },
255    }
256}
257
258fn primitive_type(ty: &Type, name: &str) -> EndpointType {
259    EndpointType::Primitive {
260        candid: ty.to_string(),
261        cardinality: EndpointCardinality::Single,
262        name: name.to_string(),
263    }
264}
265
266fn named_type(env: &TypeEnv, ty: &Type, name: &str, named_stack: &mut Vec<String>) -> EndpointType {
267    let (cardinality, resolved) = if named_stack.iter().any(|seen| seen == name) {
268        (EndpointCardinality::Single, None)
269    } else if let Ok(resolved) = env.find_type(name) {
270        named_stack.push(name.to_string());
271        let cardinality = endpoint_cardinality(env, resolved, named_stack);
272        let resolved = endpoint_type(env, resolved, named_stack);
273        named_stack.pop();
274        (cardinality, Some(Box::new(resolved)))
275    } else {
276        (EndpointCardinality::Single, None)
277    };
278    EndpointType::Named {
279        candid: ty.to_string(),
280        cardinality,
281        name: name.to_string(),
282        resolved,
283    }
284}
285
286fn endpoint_cardinality(
287    env: &TypeEnv,
288    ty: &Type,
289    named_stack: &mut Vec<String>,
290) -> EndpointCardinality {
291    match ty.as_ref() {
292        TypeInner::Opt(_) => EndpointCardinality::Optional,
293        TypeInner::Vec(_) => EndpointCardinality::Many,
294        TypeInner::Var(name) if !named_stack.iter().any(|seen| seen == name) => {
295            if let Ok(resolved) = env.find_type(name) {
296                named_stack.push(name.clone());
297                let cardinality = endpoint_cardinality(env, resolved, named_stack);
298                named_stack.pop();
299                cardinality
300            } else {
301                EndpointCardinality::Single
302            }
303        }
304        _ => EndpointCardinality::Single,
305    }
306}
307
308fn endpoint_fields(
309    env: &TypeEnv,
310    fields: &[candid::types::Field],
311    named_stack: &mut Vec<String>,
312) -> Vec<EndpointField> {
313    fields
314        .iter()
315        .map(|field| EndpointField {
316            label: field_label(&field.id),
317            id: field.id.get_id(),
318            ty: endpoint_type(env, &field.ty, named_stack),
319        })
320        .collect()
321}
322
323fn field_label(label: &Label) -> String {
324    match label {
325        Label::Named(name) => name.clone(),
326        Label::Id(id) | Label::Unnamed(id) => id.to_string(),
327    }
328}
329
330fn endpoint_modes(function: &Function) -> Vec<EndpointMode> {
331    function
332        .modes
333        .iter()
334        .map(|mode| match mode {
335            FuncMode::Query => EndpointMode::Query,
336            FuncMode::CompositeQuery => EndpointMode::CompositeQuery,
337            FuncMode::Oneway => EndpointMode::Oneway,
338        })
339        .collect()
340}
341
342/// Render a Candid method name, quoting identifiers that Candid requires quoted.
343#[must_use]
344pub fn render_candid_method_name(name: &str) -> String {
345    if is_candid_identifier(name) && !is_candid_reserved_word(name) {
346        name.to_string()
347    } else {
348        format!("{name:?}")
349    }
350}
351
/// Return `true` if `name` is a non-empty ASCII identifier.
///
/// The name must start with a letter or `_` and contain only letters, digits, or `_`.
fn is_candid_identifier(name: &str) -> bool {
    // Any non-ASCII character contributes bytes >= 0x80, which fail both checks below,
    // so a byte-level scan is equivalent to a char-level one here.
    match name.as_bytes().split_first() {
        None => false,
        Some((head, tail)) => {
            (head.is_ascii_alphabetic() || *head == b'_')
                && tail
                    .iter()
                    .all(|byte| byte.is_ascii_alphanumeric() || *byte == b'_')
        }
    }
}
360
/// Return `true` if `name` is one of the words this module always quotes when rendering
/// a Candid method name.
///
/// The list consists of Candid type and annotation keywords plus `true` and `false`.
fn is_candid_reserved_word(name: &str) -> bool {
    const RESERVED: &[&str] = &[
        "blob",
        "bool",
        "composite_query",
        "empty",
        "false",
        "float32",
        "float64",
        "func",
        "import",
        "int",
        "int8",
        "int16",
        "int32",
        "int64",
        "nat",
        "nat8",
        "nat16",
        "nat32",
        "nat64",
        "null",
        "oneway",
        "opt",
        "principal",
        "query",
        "record",
        "reserved",
        "service",
        "text",
        "true",
        "type",
        "variant",
        "vec",
    ];
    RESERVED.contains(&name)
}
398
// Unit tests live in the separate `tests` submodule file and are compiled only for test
// builds.
#[cfg(test)]
mod tests;