cargo_lambda_metadata/cargo/
watch.rs

1use cargo_options::Run;
2use clap::Args;
3use matchit::{InsertError, MatchError, Router};
4use serde::{
5    Deserialize, Serialize,
6    de::{Error, Visitor},
7    ser::SerializeSeq,
8};
9use serde_json::{Value, json};
10use std::{collections::HashMap, path::PathBuf};
11
12use crate::{
13    cargo::{count_common_options, serialize_common_options},
14    env::{EnvOptions, Environment},
15    error::MetadataError,
16    lambda::Timeout,
17};
18
19use cargo_lambda_remote::tls::TlsOptions;
20
/// Default value for `--invoke-address`, the address where invoke requests are sent.
///
/// Windows uses the IPv4 loopback address; other platforms use `::`, the IPv6
/// unspecified address.
const DEFAULT_INVOKE_ADDRESS: &str = if cfg!(windows) { "127.0.0.1" } else { "::" };

/// Default value for `--invoke-port`.
const DEFAULT_INVOKE_PORT: u16 = 9000;
28
// Options for the `watch` subcommand (visible alias `start`). The same struct
// is used for CLI parsing (clap) and for (de)serialization (serde).
//
// NOTE: clap turns `///` doc comments on these fields into `--help` text, so
// additional maintainer notes in this struct are plain `//` comments to keep
// the help output unchanged.
//
// NOTE: the derived `Default` produces an empty `invoke_address` and an
// `invoke_port` of 0. This differs from the clap/serde defaults
// (`DEFAULT_INVOKE_ADDRESS` / `DEFAULT_INVOKE_PORT`).
#[derive(Args, Clone, Debug, Default, Deserialize)]
#[command(
    name = "watch",
    visible_alias = "start",
    after_help = "Full command documentation: https://www.cargo-lambda.info/commands/watch.html"
)]
pub struct Watch {
    /// Ignore any code changes, and don't reload the function automatically
    #[arg(long, visible_alias = "no-reload")]
    #[serde(default)]
    pub ignore_changes: bool,

    /// Start the Lambda runtime APIs without starting the function.
    /// This is useful if you start (and debug) your function in your IDE.
    #[arg(long)]
    #[serde(default)]
    pub only_lambda_apis: bool,

    // The CLI default and the serde default both come from DEFAULT_INVOKE_ADDRESS.
    #[arg(short = 'a', long, default_value = DEFAULT_INVOKE_ADDRESS)]
    #[serde(default = "default_invoke_address")]
    /// Address where users send invoke requests
    pub invoke_address: String,

    /// Address port where users send invoke requests
    #[arg(short = 'P', long, default_value_t = DEFAULT_INVOKE_PORT)]
    #[serde(default = "default_invoke_port")]
    pub invoke_port: u16,

    /// Print OpenTelemetry traces after each function invocation
    #[arg(long)]
    #[serde(default)]
    pub print_traces: bool,

    /// Wait for the first invocation to compile the function
    #[arg(long, short)]
    #[serde(default)]
    pub wait: bool,

    /// Disable the default CORS configuration
    #[arg(long)]
    #[serde(default)]
    pub disable_cors: bool,

    /// How long the invoke request waits for a response
    #[arg(long)]
    #[serde(default)]
    pub timeout: Option<Timeout>,

    // Cargo `run` options (manifest path, release, packages, bin, example,
    // args and the common cargo flags). They are flattened into both the CLI
    // and the serialized form.
    #[command(flatten)]
    #[serde(flatten)]
    pub cargo_opts: Run,

    // Environment options for the function (e.g. `env_file`, `env_var`), flattened.
    #[command(flatten)]
    #[serde(flatten)]
    pub env_options: EnvOptions,

    // TLS options (e.g. `tls_cert`, `tls_key`, `tls_ca`), flattened.
    #[command(flatten)]
    #[serde(flatten)]
    pub tls_options: TlsOptions,

    // Optional path -> function routing table. `#[arg(skip)]` keeps it off the
    // command line; it is only populated through deserialization or code.
    #[arg(skip)]
    #[serde(default)]
    pub router: Option<FunctionRouter>,
}
93
94impl Watch {
95    pub fn manifest_path(&self) -> PathBuf {
96        self.cargo_opts
97            .manifest_path
98            .clone()
99            .unwrap_or_else(|| "Cargo.toml".into())
100    }
101
102    /// Returns the package name if there is only one package in the list of `packages`,
103    /// otherwise None.
104    pub fn package(&self) -> Option<String> {
105        if self.cargo_opts.packages.len() > 1 {
106            return None;
107        }
108        self.cargo_opts.packages.first().map(|s| s.to_string())
109    }
110
111    pub fn lambda_environment(
112        &self,
113        base: &HashMap<String, String>,
114    ) -> Result<Environment, MetadataError> {
115        self.env_options.lambda_environment(base)
116    }
117}
118
/// Hand-written serializer that omits "empty" values.
///
/// `false` booleans, an empty `invoke_address`, an `invoke_port` of 0, `None`
/// options and empty collections are skipped. The fields of `env_options`,
/// `tls_options` and `cargo_opts` are written at the top level (flattened)
/// rather than as nested objects.
impl Serialize for Watch {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        // Count the fields that will actually be written. This must mirror the
        // conditions used below, so that the length given to `serialize_struct`
        // matches the number of `serialize_field` calls. The helper counts
        // (`count_common_options`, `count_fields`) are assumed to mirror their
        // matching `serialize_*` helpers — verify when changing either side.
        let field_count = self.ignore_changes as usize
            + self.only_lambda_apis as usize
            + !self.invoke_address.is_empty() as usize
            + (self.invoke_port != 0) as usize
            + self.print_traces as usize
            + self.wait as usize
            + self.disable_cors as usize
            + self.timeout.is_some() as usize
            + self.router.is_some() as usize
            + self.cargo_opts.manifest_path.is_some() as usize
            + self.cargo_opts.release as usize
            + self.cargo_opts.ignore_rust_version as usize
            + self.cargo_opts.unit_graph as usize
            + !self.cargo_opts.packages.is_empty() as usize
            + !self.cargo_opts.bin.is_empty() as usize
            + !self.cargo_opts.example.is_empty() as usize
            + !self.cargo_opts.args.is_empty() as usize
            + count_common_options(&self.cargo_opts.common)
            + self.env_options.count_fields()
            + self.tls_options.count_fields();

        let mut state = serializer.serialize_struct("Watch", field_count)?;

        // Bool fields are written only when true; the address only when
        // non-empty, and the port only when non-zero.
        if self.ignore_changes {
            state.serialize_field("ignore_changes", &true)?;
        }
        if self.only_lambda_apis {
            state.serialize_field("only_lambda_apis", &true)?;
        }
        if !self.invoke_address.is_empty() {
            state.serialize_field("invoke_address", &self.invoke_address)?;
        }
        if self.invoke_port != 0 {
            state.serialize_field("invoke_port", &self.invoke_port)?;
        }
        if self.print_traces {
            state.serialize_field("print_traces", &true)?;
        }
        if self.wait {
            state.serialize_field("wait", &true)?;
        }
        if self.disable_cors {
            state.serialize_field("disable_cors", &true)?;
        }

        // Only serialize Some values for Options
        if let Some(timeout) = &self.timeout {
            state.serialize_field("timeout", timeout)?;
        }
        if let Some(router) = &self.router {
            state.serialize_field("router", router)?;
        }

        // Flatten the env_options and tls_options fields into this struct;
        // the cargo_opts fields are flattened the same way below.
        self.env_options.serialize_fields::<S>(&mut state)?;
        self.tls_options.serialize_fields::<S>(&mut state)?;

        if let Some(manifest_path) = &self.cargo_opts.manifest_path {
            state.serialize_field("manifest_path", manifest_path)?;
        }
        if self.cargo_opts.release {
            state.serialize_field("release", &true)?;
        }
        if self.cargo_opts.ignore_rust_version {
            state.serialize_field("ignore_rust_version", &true)?;
        }
        if self.cargo_opts.unit_graph {
            state.serialize_field("unit_graph", &true)?;
        }
        if !self.cargo_opts.packages.is_empty() {
            state.serialize_field("packages", &self.cargo_opts.packages)?;
        }
        if !self.cargo_opts.bin.is_empty() {
            state.serialize_field("bin", &self.cargo_opts.bin)?;
        }
        if !self.cargo_opts.example.is_empty() {
            state.serialize_field("example", &self.cargo_opts.example)?;
        }
        if !self.cargo_opts.args.is_empty() {
            state.serialize_field("args", &self.cargo_opts.args)?;
        }
        // Common cargo flags (features, target, profile, ...) via the shared helper.
        serialize_common_options::<S>(&mut state, &self.cargo_opts.common)?;

        state.end()
    }
}
214
215fn default_invoke_address() -> String {
216    DEFAULT_INVOKE_ADDRESS.to_string()
217}
218
219fn default_invoke_port() -> u16 {
220    DEFAULT_INVOKE_PORT
221}
222
/// Watch-related configuration, currently holding only the optional function
/// router. Presumably deserialized from project metadata — confirm against callers.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct WatchConfig {
    /// Optional path -> function routing table. A missing key deserializes
    /// as `None` (serde's handling of `Option` fields).
    pub router: Option<FunctionRouter>,
}
227
/// Routes request paths (and methods) to function names.
///
/// `inner` performs the path matching. `raw` keeps the `(path, routes)` pairs
/// the router was built from; this is what the `Serialize` impl writes out.
#[derive(Clone, Debug, Default)]
pub struct FunctionRouter {
    /// Path matcher keyed by route pattern (e.g. `/api/v1/users/{id}`).
    inner: Router<FunctionRoutes>,
    /// Source `(path, routes)` pairs, used for serialization.
    pub(crate) raw: Vec<(String, FunctionRoutes)>,
}
233
234impl FunctionRouter {
235    pub fn at(
236        &self,
237        path: &str,
238        method: &str,
239    ) -> Result<(String, HashMap<String, String>), MatchError> {
240        let matched = self.inner.at(path)?;
241        let function = matched.value.at(method).ok_or(MatchError::NotFound)?;
242
243        let params = matched
244            .params
245            .iter()
246            .map(|(k, v)| (k.to_string(), v.to_string()))
247            .collect();
248
249        Ok((function.to_string(), params))
250    }
251
252    pub fn insert(&mut self, path: &str, routes: FunctionRoutes) -> Result<(), InsertError> {
253        self.inner.insert(path, routes)
254    }
255}
256
/// The function(s) registered for a single route path.
#[derive(Clone, Debug, PartialEq)]
pub enum FunctionRoutes {
    /// One function that handles every method on the path.
    Single(String),
    /// Method name (e.g. `"GET"`) -> function name. Methods that are not
    /// listed do not match (lookup is an exact, case-sensitive key match).
    Multiple(HashMap<String, String>),
}
262
263impl FunctionRoutes {
264    pub fn at(&self, method: &str) -> Option<&str> {
265        match self {
266            FunctionRoutes::Single(function) => Some(function),
267            FunctionRoutes::Multiple(routes) => routes.get(method).map(|s| s.as_str()),
268        }
269    }
270}
271
272struct FunctionRouterVisitor;
273
274impl<'de> Visitor<'de> for FunctionRouterVisitor {
275    type Value = FunctionRouter;
276
277    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
278        formatter.write_str("a map or sequence of function routes")
279    }
280
281    fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
282    where
283        A: serde::de::MapAccess<'de>,
284    {
285        let routes: HashMap<String, FunctionRoutes> =
286            Deserialize::deserialize(serde::de::value::MapAccessDeserializer::new(map))?;
287        let mut inner = Router::new();
288
289        for (path, route) in &routes {
290            inner.insert(path, route.clone()).map_err(|e| {
291                serde::de::Error::custom(format!("Failed to insert route {path}: {e}"))
292            })?;
293        }
294
295        let raw: Vec<(String, FunctionRoutes)> = routes.into_iter().collect();
296        Ok(FunctionRouter { inner, raw })
297    }
298
299    fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
300    where
301        A: serde::de::SeqAccess<'de>,
302    {
303        let raw: Vec<(String, FunctionRoutes)> =
304            Deserialize::deserialize(serde::de::value::SeqAccessDeserializer::new(seq))?;
305        let mut inner = Router::new();
306
307        for (path, route) in &raw {
308            inner.insert(path, route.clone()).map_err(|e| {
309                serde::de::Error::custom(format!("Failed to insert route {path}: {e}"))
310            })?;
311        }
312
313        Ok(FunctionRouter { inner, raw })
314    }
315}
316
317impl<'de> Deserialize<'de> for FunctionRouter {
318    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
319    where
320        D: serde::Deserializer<'de>,
321    {
322        deserializer.deserialize_any(FunctionRouterVisitor)
323    }
324}
325
326impl Serialize for FunctionRouter {
327    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
328    where
329        S: serde::Serializer,
330    {
331        self.raw.serialize(serializer)
332    }
333}
334
335impl<'de> Deserialize<'de> for FunctionRoutes {
336    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
337    where
338        D: serde::Deserializer<'de>,
339    {
340        let value = Value::deserialize(deserializer)?;
341        match value {
342            Value::String(s) => Ok(FunctionRoutes::Single(s)),
343            Value::Array(arr) => {
344                let mut routes = HashMap::new();
345                for item in arr {
346                    let obj = item.as_object().ok_or_else(|| {
347                        Error::custom("Array items must be objects with method and function fields")
348                    })?;
349
350                    let method = obj
351                        .get("method")
352                        .and_then(|m| m.as_str())
353                        .ok_or_else(|| Error::custom("Missing or invalid method field"))?;
354
355                    let function = obj
356                        .get("function")
357                        .and_then(|f| f.as_str())
358                        .ok_or_else(|| Error::custom("Missing or invalid function field"))?;
359
360                    routes.insert(method.to_string(), function.to_string());
361                }
362                Ok(FunctionRoutes::Multiple(routes))
363            }
364            _ => Err(Error::custom(
365                "Function routes must be either a string or an array of objects with method and function fields",
366            )),
367        }
368    }
369}
370
371impl Serialize for FunctionRoutes {
372    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
373    where
374        S: serde::Serializer,
375    {
376        match self {
377            FunctionRoutes::Single(function) => function.serialize(serializer),
378            FunctionRoutes::Multiple(routes) => {
379                let mut seq = serializer.serialize_seq(Some(routes.len()))?;
380                for (method, function) in routes {
381                    let mut map = serde_json::Map::new();
382                    map.insert("method".to_string(), json!(method));
383                    map.insert("function".to_string(), json!(function));
384                    seq.serialize_element(&Value::Object(map))?;
385                }
386                seq.end()
387            }
388        }
389    }
390}
391
// Unit tests for the function router and for `Watch` (de)serialization.
#[cfg(test)]
mod tests {

    use cargo_options::CommonOptions;
    use serde_json::{Value, json};
    use std::path::PathBuf;

    use super::*;

    // Map-shaped TOML (the config form) must produce both Single and Multiple routes.
    #[test]
    fn test_router_deserialize() {
        let router: FunctionRouter = toml::from_str(
            r#"
            "/api/v1/users" = [
                { function = "get_user", method = "GET" },
                { function = "create_user", method = "POST" }
            ]
            "/api/v1/all_methods" = "all_methods"
        "#,
        )
        .unwrap();

        assert_eq!(
            router.inner.at("/api/v1/users").unwrap().value,
            &FunctionRoutes::Multiple(HashMap::from([
                ("GET".to_string(), "get_user".to_string()),
                ("POST".to_string(), "create_user".to_string()),
            ]))
        );

        assert_eq!(
            router.inner.at("/api/v1/all_methods").unwrap().value,
            &FunctionRoutes::Single("all_methods".to_string())
        );
    }

    // Lookup behaviour: empty router, Single (any method), Multiple (listed
    // methods only) and captured path parameters.
    #[test]
    fn test_router_get() {
        let router = FunctionRouter::default();
        assert_eq!(router.at("/api/v1/users", "GET"), Err(MatchError::NotFound));

        let mut inner = Router::new();
        inner
            .insert(
                "/api/v1/users",
                FunctionRoutes::Single("user_handler".to_string()),
            )
            .unwrap();
        let router = FunctionRouter {
            inner,
            ..Default::default()
        };
        assert_eq!(
            router.at("/api/v1/users", "GET"),
            Ok(("user_handler".to_string(), HashMap::new()))
        );
        assert_eq!(
            router.at("/api/v1/users", "POST"),
            Ok(("user_handler".to_string(), HashMap::new()))
        );

        let mut inner = Router::new();
        inner
            .insert(
                "/api/v1/users",
                FunctionRoutes::Multiple(HashMap::from([
                    ("GET".to_string(), "get_user".to_string()),
                    ("POST".to_string(), "create_user".to_string()),
                ])),
            )
            .unwrap();
        let router = FunctionRouter {
            inner,
            ..Default::default()
        };
        assert_eq!(
            router.at("/api/v1/users", "GET"),
            Ok(("get_user".to_string(), HashMap::new()))
        );
        assert_eq!(
            router.at("/api/v1/users", "POST"),
            Ok(("create_user".to_string(), HashMap::new()))
        );
        // A method absent from a Multiple route is reported as NotFound.
        assert_eq!(router.at("/api/v1/users", "PUT"), Err(MatchError::NotFound));

        let mut inner = Router::new();
        inner
            .insert(
                "/api/v1/users/{id}",
                FunctionRoutes::Single("user_handler".to_string()),
            )
            .unwrap();
        let router = FunctionRouter {
            inner,
            ..Default::default()
        };

        let (function, params) = router.at("/api/v1/users/1", "GET").unwrap();
        assert_eq!(function, "user_handler");
        assert_eq!(params, HashMap::from([("id".to_string(), "1".to_string())]));
    }

    // Round trip: TOML map -> router -> JSON sequence -> router keeps the same routes.
    #[test]
    fn test_router_serialize() {
        let config = r#"
            "/api/v1/users" = [
                { function = "get_user", method = "GET" },
                { function = "create_user", method = "POST" }
            ]
            "/api/v1/all_methods" = "all_methods"
        "#;
        let router: FunctionRouter = toml::from_str(config).unwrap();

        let json = serde_json::to_value(&router).unwrap();

        let new_router: FunctionRouter = serde_json::from_value(json).unwrap();
        assert_eq!(new_router.raw, router.raw);

        assert_eq!(
            new_router.inner.at("/api/v1/users").unwrap().value,
            &FunctionRoutes::Multiple(HashMap::from([
                ("GET".to_string(), "get_user".to_string()),
                ("POST".to_string(), "create_user".to_string()),
            ]))
        );

        assert_eq!(
            new_router.inner.at("/api/v1/all_methods").unwrap().value,
            &FunctionRoutes::Single("all_methods".to_string())
        );
    }

    // `Watch` serializes with flattened env/TLS/cargo fields, omits unset
    // values, and deserializes back to the same values.
    #[test]
    fn test_watch_serialization() {
        let watch = Watch {
            invoke_address: "127.0.0.1".to_string(),
            invoke_port: 9000,
            env_options: EnvOptions {
                env_file: Some(PathBuf::from("/tmp/env")),
                env_var: Some(vec!["FOO=BAR".to_string()]),
            },
            tls_options: TlsOptions::new(
                Some(PathBuf::from("/tmp/cert.pem")),
                Some(PathBuf::from("/tmp/key.pem")),
                Some(PathBuf::from("/tmp/ca.pem")),
            ),
            cargo_opts: Run {
                common: CommonOptions {
                    quiet: false,
                    jobs: None,
                    keep_going: false,
                    profile: None,
                    features: vec!["feature1".to_string()],
                    all_features: false,
                    no_default_features: true,
                    target: vec!["x86_64-unknown-linux-gnu".to_string()],
                    target_dir: Some(PathBuf::from("/tmp/target")),
                    message_format: vec!["json".to_string()],
                    verbose: 1,
                    color: Some("auto".to_string()),
                    frozen: true,
                    locked: true,
                    offline: true,
                    config: vec!["config.toml".to_string()],
                    unstable_flags: vec!["flag1".to_string()],
                    timings: None,
                },
                manifest_path: None,
                release: false,
                ignore_rust_version: false,
                unit_graph: false,
                packages: vec![],
                bin: vec![],
                example: vec![],
                args: vec![],
            },
            ..Default::default()
        };

        let json = serde_json::to_value(&watch).unwrap();
        assert_eq!(json["invoke_address"], "127.0.0.1");
        assert_eq!(json["invoke_port"], 9000);
        assert_eq!(json["env_file"], "/tmp/env");
        assert_eq!(json["env_var"], json!(["FOO=BAR"]));
        assert_eq!(json["tls_cert"], "/tmp/cert.pem");
        assert_eq!(json["tls_key"], "/tmp/key.pem");
        assert_eq!(json["tls_ca"], "/tmp/ca.pem");
        assert_eq!(json["features"], json!(["feature1"]));
        assert_eq!(json["no_default_features"], true);
        assert_eq!(json["target"], json!(["x86_64-unknown-linux-gnu"]));
        assert_eq!(json["target_dir"], "/tmp/target");
        assert_eq!(json["message_format"], json!(["json"]));
        assert_eq!(json["verbose"], 1);
        assert_eq!(json["color"], "auto");
        assert_eq!(json["frozen"], true);
        assert_eq!(json["locked"], true);
        assert_eq!(json["offline"], true);
        assert_eq!(json["config"], json!(["config.toml"]));
        assert_eq!(json["unstable_flags"], json!(["flag1"]));
        // `timings` is None, so the key is absent and indexing yields Null.
        assert_eq!(json["timings"], Value::Null);

        let deserialized: Watch = serde_json::from_value(json).unwrap();

        assert_eq!(deserialized.invoke_address, watch.invoke_address);
        assert_eq!(deserialized.invoke_port, watch.invoke_port);
        assert_eq!(
            deserialized.env_options.env_file,
            watch.env_options.env_file
        );
        assert_eq!(deserialized.env_options.env_var, watch.env_options.env_var);
        assert_eq!(
            deserialized.tls_options.tls_cert,
            watch.tls_options.tls_cert
        );
        assert_eq!(deserialized.tls_options.tls_key, watch.tls_options.tls_key);
        assert_eq!(deserialized.tls_options.tls_ca, watch.tls_options.tls_ca);
        assert_eq!(
            deserialized.cargo_opts.common.features,
            watch.cargo_opts.common.features
        );
        assert_eq!(
            deserialized.cargo_opts.common.no_default_features,
            watch.cargo_opts.common.no_default_features
        );
        assert_eq!(
            deserialized.cargo_opts.common.target,
            watch.cargo_opts.common.target
        );
        assert_eq!(
            deserialized.cargo_opts.common.target_dir,
            watch.cargo_opts.common.target_dir
        );
        assert_eq!(
            deserialized.cargo_opts.common.message_format,
            watch.cargo_opts.common.message_format
        );
        assert_eq!(
            deserialized.cargo_opts.common.verbose,
            watch.cargo_opts.common.verbose
        );
        assert_eq!(
            deserialized.cargo_opts.common.color,
            watch.cargo_opts.common.color
        );
        assert_eq!(
            deserialized.cargo_opts.common.frozen,
            watch.cargo_opts.common.frozen
        );
        assert_eq!(
            deserialized.cargo_opts.common.locked,
            watch.cargo_opts.common.locked
        );
        assert_eq!(
            deserialized.cargo_opts.common.offline,
            watch.cargo_opts.common.offline
        );
        assert_eq!(
            deserialized.cargo_opts.common.config,
            watch.cargo_opts.common.config
        );
        assert_eq!(
            deserialized.cargo_opts.common.unstable_flags,
            watch.cargo_opts.common.unstable_flags
        );
        assert_eq!(
            deserialized.cargo_opts.common.timings,
            watch.cargo_opts.common.timings
        );
    }
}