1use cargo_options::Run;
2use clap::Args;
3use matchit::{InsertError, MatchError, Router};
4use serde::{
5 Deserialize, Serialize,
6 de::{Error, Visitor},
7 ser::SerializeSeq,
8};
9use serde_json::{Value, json};
10use std::{collections::HashMap, path::PathBuf};
11
12use crate::{
13 cargo::{count_common_options, serialize_common_options},
14 env::{EnvOptions, Environment},
15 error::MetadataError,
16 lambda::Timeout,
17};
18
19use cargo_lambda_remote::tls::TlsOptions;
20
/// Default value of the `invoke_address` option on Windows targets (IPv4 loopback).
// NOTE(review): presumably Windows gets a different default because binding `::` behaves
// differently there — confirm the original motivation.
#[cfg(windows)]
const DEFAULT_INVOKE_ADDRESS: &str = "127.0.0.1";

/// Default value of the `invoke_address` option on every non-Windows target.
/// `::` is the IPv6 unspecified (wildcard) address.
#[cfg(not(windows))]
const DEFAULT_INVOKE_ADDRESS: &str = "::";

/// Default value of the `invoke_port` option.
const DEFAULT_INVOKE_PORT: u16 = 9000;
28
// Options of the `watch` subcommand (also reachable through the visible alias `start`).
//
// The struct can be filled in two ways: by clap from the command line (`Args`) or by serde
// from a deserialized document (`Deserialize`). The flattened groups contribute their own
// flags/keys at the same level as the fields declared directly here.
//
// NOTE: the comments in this definition are deliberately plain `//` comments. clap's derive
// macros turn `///` doc comments into help/about text, so using doc comments here would
// change the generated `--help` output.
#[derive(Args, Clone, Debug, Default, Deserialize)]
#[command(
    name = "watch",
    visible_alias = "start",
    after_help = "Full command documentation: https://www.cargo-lambda.info/commands/watch.html"
)]
pub struct Watch {
    // `--ignore-changes` (visible alias `--no-reload`); `false` when absent.
    // NOTE(review): presumably disables reloading on source changes — confirm in the watcher.
    #[arg(long, visible_alias = "no-reload")]
    #[serde(default)]
    pub ignore_changes: bool,

    // `--only-lambda-apis`; `false` when absent.
    // NOTE(review): presumably starts only the runtime APIs without the function — confirm.
    #[arg(long)]
    #[serde(default)]
    pub only_lambda_apis: bool,

    // `-a` / `--invoke-address`; falls back to `DEFAULT_INVOKE_ADDRESS` both on the command
    // line and when the key is missing from a deserialized document.
    #[arg(short = 'a', long, default_value = DEFAULT_INVOKE_ADDRESS)]
    #[serde(default = "default_invoke_address")]
    pub invoke_address: String,

    // `-P` / `--invoke-port`; falls back to `DEFAULT_INVOKE_PORT` both on the command line
    // and when the key is missing from a deserialized document.
    #[arg(short = 'P', long, default_value_t = DEFAULT_INVOKE_PORT)]
    #[serde(default = "default_invoke_port")]
    pub invoke_port: u16,

    // `--print-traces`; `false` when absent.
    // NOTE(review): the kind of traces printed is not visible from this file — confirm.
    #[arg(long)]
    #[serde(default)]
    pub print_traces: bool,

    // `--wait`, plus a short flag that clap derives from the field name; `false` when absent.
    // NOTE(review): what exactly is waited for is decided by the consumer — confirm.
    #[arg(long, short)]
    #[serde(default)]
    pub wait: bool,

    // `--disable-cors`; `false` when absent.
    #[arg(long)]
    #[serde(default)]
    pub disable_cors: bool,

    // `--timeout`; `None` when absent. Parsing and units are defined by `Timeout`.
    #[arg(long)]
    #[serde(default)]
    pub timeout: Option<Timeout>,

    // Options of `cargo_options::Run`, flattened into this command and this document level.
    #[command(flatten)]
    #[serde(flatten)]
    pub cargo_opts: Run,

    // Environment options (`EnvOptions`), flattened; consumed by `lambda_environment`.
    #[command(flatten)]
    #[serde(flatten)]
    pub env_options: EnvOptions,

    // TLS options (`TlsOptions`), flattened.
    #[command(flatten)]
    #[serde(flatten)]
    pub tls_options: TlsOptions,

    // Function router. `#[arg(skip)]` keeps it off the command line, so it can only come
    // from a deserialized document (or be set programmatically); `None` when absent.
    #[arg(skip)]
    #[serde(default)]
    pub router: Option<FunctionRouter>,
}
93
94impl Watch {
95 pub fn manifest_path(&self) -> PathBuf {
96 self.cargo_opts
97 .manifest_path
98 .clone()
99 .unwrap_or_else(|| "Cargo.toml".into())
100 }
101
102 pub fn package(&self) -> Option<String> {
105 if self.cargo_opts.packages.len() > 1 {
106 return None;
107 }
108 self.cargo_opts.packages.first().map(|s| s.to_string())
109 }
110
111 pub fn lambda_environment(
112 &self,
113 base: &HashMap<String, String>,
114 ) -> Result<Environment, MetadataError> {
115 self.env_options.lambda_environment(base)
116 }
117}
118
119impl Serialize for Watch {
120 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
121 where
122 S: serde::Serializer,
123 {
124 use serde::ser::SerializeStruct;
125
126 let field_count = self.ignore_changes as usize
128 + self.only_lambda_apis as usize
129 + !self.invoke_address.is_empty() as usize
130 + (self.invoke_port != 0) as usize
131 + self.print_traces as usize
132 + self.wait as usize
133 + self.disable_cors as usize
134 + self.timeout.is_some() as usize
135 + self.router.is_some() as usize
136 + self.cargo_opts.manifest_path.is_some() as usize
137 + self.cargo_opts.release as usize
138 + self.cargo_opts.ignore_rust_version as usize
139 + self.cargo_opts.unit_graph as usize
140 + !self.cargo_opts.packages.is_empty() as usize
141 + !self.cargo_opts.bin.is_empty() as usize
142 + !self.cargo_opts.example.is_empty() as usize
143 + !self.cargo_opts.args.is_empty() as usize
144 + count_common_options(&self.cargo_opts.common)
145 + self.env_options.count_fields()
146 + self.tls_options.count_fields();
147
148 let mut state = serializer.serialize_struct("Watch", field_count)?;
149
150 if self.ignore_changes {
152 state.serialize_field("ignore_changes", &true)?;
153 }
154 if self.only_lambda_apis {
155 state.serialize_field("only_lambda_apis", &true)?;
156 }
157 if !self.invoke_address.is_empty() {
158 state.serialize_field("invoke_address", &self.invoke_address)?;
159 }
160 if self.invoke_port != 0 {
161 state.serialize_field("invoke_port", &self.invoke_port)?;
162 }
163 if self.print_traces {
164 state.serialize_field("print_traces", &true)?;
165 }
166 if self.wait {
167 state.serialize_field("wait", &true)?;
168 }
169 if self.disable_cors {
170 state.serialize_field("disable_cors", &true)?;
171 }
172
173 if let Some(timeout) = &self.timeout {
175 state.serialize_field("timeout", timeout)?;
176 }
177 if let Some(router) = &self.router {
178 state.serialize_field("router", router)?;
179 }
180
181 self.env_options.serialize_fields::<S>(&mut state)?;
183 self.tls_options.serialize_fields::<S>(&mut state)?;
184
185 if let Some(manifest_path) = &self.cargo_opts.manifest_path {
186 state.serialize_field("manifest_path", manifest_path)?;
187 }
188 if self.cargo_opts.release {
189 state.serialize_field("release", &true)?;
190 }
191 if self.cargo_opts.ignore_rust_version {
192 state.serialize_field("ignore_rust_version", &true)?;
193 }
194 if self.cargo_opts.unit_graph {
195 state.serialize_field("unit_graph", &true)?;
196 }
197 if !self.cargo_opts.packages.is_empty() {
198 state.serialize_field("packages", &self.cargo_opts.packages)?;
199 }
200 if !self.cargo_opts.bin.is_empty() {
201 state.serialize_field("bin", &self.cargo_opts.bin)?;
202 }
203 if !self.cargo_opts.example.is_empty() {
204 state.serialize_field("example", &self.cargo_opts.example)?;
205 }
206 if !self.cargo_opts.args.is_empty() {
207 state.serialize_field("args", &self.cargo_opts.args)?;
208 }
209 serialize_common_options::<S>(&mut state, &self.cargo_opts.common)?;
210
211 state.end()
212 }
213}
214
215fn default_invoke_address() -> String {
216 DEFAULT_INVOKE_ADDRESS.to_string()
217}
218
/// serde default for `Watch::invoke_port`: returns `DEFAULT_INVOKE_PORT`.
fn default_invoke_port() -> u16 {
    DEFAULT_INVOKE_PORT
}
222
/// Serializable configuration that carries only the optional function router.
// NOTE(review): presumably the config-file counterpart of `Watch::router` — confirm with callers.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct WatchConfig {
    /// Router definition, if one was provided.
    pub router: Option<FunctionRouter>,
}
227
/// Maps request paths (and methods) to function names.
///
/// Two views of the same routes are kept: a `matchit` router used for lookups and the
/// plain list of `(path, routes)` pairs, which is what gets serialized.
#[derive(Clone, Debug, Default)]
pub struct FunctionRouter {
    /// Matcher used by lookups.
    inner: Router<FunctionRoutes>,
    /// The `(path, routes)` pairs; this is the data written out on serialization.
    pub(crate) raw: Vec<(String, FunctionRoutes)>,
}
233
234impl FunctionRouter {
235 pub fn at(
236 &self,
237 path: &str,
238 method: &str,
239 ) -> Result<(String, HashMap<String, String>), MatchError> {
240 let matched = self.inner.at(path)?;
241 let function = matched.value.at(method).ok_or(MatchError::NotFound)?;
242
243 let params = matched
244 .params
245 .iter()
246 .map(|(k, v)| (k.to_string(), v.to_string()))
247 .collect();
248
249 Ok((function.to_string(), params))
250 }
251
252 pub fn insert(&mut self, path: &str, routes: FunctionRoutes) -> Result<(), InsertError> {
253 self.inner.insert(path, routes)
254 }
255}
256
/// The function or functions registered for one route path.
#[derive(Clone, Debug, PartialEq)]
pub enum FunctionRoutes {
    /// A single function that answers for every method.
    Single(String),
    /// Functions keyed by method; the key is compared exactly as given.
    Multiple(HashMap<String, String>),
}

impl FunctionRoutes {
    /// Returns the function registered for `method`.
    ///
    /// `Single` ignores `method` and always yields its function; `Multiple` yields `None`
    /// when `method` is not one of its keys.
    pub fn at(&self, method: &str) -> Option<&str> {
        let function = match self {
            Self::Single(function) => function,
            Self::Multiple(by_method) => by_method.get(method)?,
        };
        Some(function.as_str())
    }
}
271
272struct FunctionRouterVisitor;
273
274impl<'de> Visitor<'de> for FunctionRouterVisitor {
275 type Value = FunctionRouter;
276
277 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
278 formatter.write_str("a map or sequence of function routes")
279 }
280
281 fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
282 where
283 A: serde::de::MapAccess<'de>,
284 {
285 let routes: HashMap<String, FunctionRoutes> =
286 Deserialize::deserialize(serde::de::value::MapAccessDeserializer::new(map))?;
287 let mut inner = Router::new();
288
289 for (path, route) in &routes {
290 inner.insert(path, route.clone()).map_err(|e| {
291 serde::de::Error::custom(format!("Failed to insert route {path}: {e}"))
292 })?;
293 }
294
295 let raw: Vec<(String, FunctionRoutes)> = routes.into_iter().collect();
296 Ok(FunctionRouter { inner, raw })
297 }
298
299 fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
300 where
301 A: serde::de::SeqAccess<'de>,
302 {
303 let raw: Vec<(String, FunctionRoutes)> =
304 Deserialize::deserialize(serde::de::value::SeqAccessDeserializer::new(seq))?;
305 let mut inner = Router::new();
306
307 for (path, route) in &raw {
308 inner.insert(path, route.clone()).map_err(|e| {
309 serde::de::Error::custom(format!("Failed to insert route {path}: {e}"))
310 })?;
311 }
312
313 Ok(FunctionRouter { inner, raw })
314 }
315}
316
/// Deserializes a router from whichever shape the input provides (map or sequence) by
/// handing a `FunctionRouterVisitor` to `deserialize_any`.
impl<'de> Deserialize<'de> for FunctionRouter {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // `deserialize_any` lets the input decide which visitor method runs, so this
        // relies on the data format being self-describing.
        deserializer.deserialize_any(FunctionRouterVisitor)
    }
}
325
326impl Serialize for FunctionRouter {
327 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
328 where
329 S: serde::Serializer,
330 {
331 self.raw.serialize(serializer)
332 }
333}
334
335impl<'de> Deserialize<'de> for FunctionRoutes {
336 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
337 where
338 D: serde::Deserializer<'de>,
339 {
340 let value = Value::deserialize(deserializer)?;
341 match value {
342 Value::String(s) => Ok(FunctionRoutes::Single(s)),
343 Value::Array(arr) => {
344 let mut routes = HashMap::new();
345 for item in arr {
346 let obj = item.as_object().ok_or_else(|| {
347 Error::custom("Array items must be objects with method and function fields")
348 })?;
349
350 let method = obj
351 .get("method")
352 .and_then(|m| m.as_str())
353 .ok_or_else(|| Error::custom("Missing or invalid method field"))?;
354
355 let function = obj
356 .get("function")
357 .and_then(|f| f.as_str())
358 .ok_or_else(|| Error::custom("Missing or invalid function field"))?;
359
360 routes.insert(method.to_string(), function.to_string());
361 }
362 Ok(FunctionRoutes::Multiple(routes))
363 }
364 _ => Err(Error::custom(
365 "Function routes must be either a string or an array of objects with method and function fields",
366 )),
367 }
368 }
369}
370
371impl Serialize for FunctionRoutes {
372 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
373 where
374 S: serde::Serializer,
375 {
376 match self {
377 FunctionRoutes::Single(function) => function.serialize(serializer),
378 FunctionRoutes::Multiple(routes) => {
379 let mut seq = serializer.serialize_seq(Some(routes.len()))?;
380 for (method, function) in routes {
381 let mut map = serde_json::Map::new();
382 map.insert("method".to_string(), json!(method));
383 map.insert("function".to_string(), json!(function));
384 seq.serialize_element(&Value::Object(map))?;
385 }
386 seq.end()
387 }
388 }
389 }
390}
391
#[cfg(test)]
mod tests {

    use cargo_options::CommonOptions;
    use serde_json::{Value, json};
    use std::path::PathBuf;

    use super::*;

    /// A TOML table whose values are either an array of `{ function, method }` tables or
    /// a plain string deserializes into the corresponding `Multiple` / `Single` routes.
    #[test]
    fn test_router_deserialize() {
        let router: FunctionRouter = toml::from_str(
            r#"
 "/api/v1/users" = [
 { function = "get_user", method = "GET" },
 { function = "create_user", method = "POST" }
 ]
 "/api/v1/all_methods" = "all_methods"
 "#,
        )
        .unwrap();

        assert_eq!(
            router.inner.at("/api/v1/users").unwrap().value,
            &FunctionRoutes::Multiple(HashMap::from([
                ("GET".to_string(), "get_user".to_string()),
                ("POST".to_string(), "create_user".to_string()),
            ]))
        );

        assert_eq!(
            router.inner.at("/api/v1/all_methods").unwrap().value,
            &FunctionRoutes::Single("all_methods".to_string())
        );
    }

    /// Exercises `FunctionRouter::at` against an empty router, a `Single` route, a
    /// `Multiple` route and a route with a path parameter.
    #[test]
    fn test_router_get() {
        // An empty router matches nothing.
        let router = FunctionRouter::default();
        assert_eq!(router.at("/api/v1/users", "GET"), Err(MatchError::NotFound));

        // A `Single` route returns the same function for every method.
        let mut inner = Router::new();
        inner
            .insert(
                "/api/v1/users",
                FunctionRoutes::Single("user_handler".to_string()),
            )
            .unwrap();
        let router = FunctionRouter {
            inner,
            ..Default::default()
        };
        assert_eq!(
            router.at("/api/v1/users", "GET"),
            Ok(("user_handler".to_string(), HashMap::new()))
        );
        assert_eq!(
            router.at("/api/v1/users", "POST"),
            Ok(("user_handler".to_string(), HashMap::new()))
        );

        // A `Multiple` route dispatches per method and rejects unlisted methods.
        let mut inner = Router::new();
        inner
            .insert(
                "/api/v1/users",
                FunctionRoutes::Multiple(HashMap::from([
                    ("GET".to_string(), "get_user".to_string()),
                    ("POST".to_string(), "create_user".to_string()),
                ])),
            )
            .unwrap();
        let router = FunctionRouter {
            inner,
            ..Default::default()
        };
        assert_eq!(
            router.at("/api/v1/users", "GET"),
            Ok(("get_user".to_string(), HashMap::new()))
        );
        assert_eq!(
            router.at("/api/v1/users", "POST"),
            Ok(("create_user".to_string(), HashMap::new()))
        );
        assert_eq!(router.at("/api/v1/users", "PUT"), Err(MatchError::NotFound));

        // Path parameters captured by the matcher are returned alongside the function.
        let mut inner = Router::new();
        inner
            .insert(
                "/api/v1/users/{id}",
                FunctionRoutes::Single("user_handler".to_string()),
            )
            .unwrap();
        let router = FunctionRouter {
            inner,
            ..Default::default()
        };

        let (function, params) = router.at("/api/v1/users/1", "GET").unwrap();
        assert_eq!(function, "user_handler");
        assert_eq!(params, HashMap::from([("id".to_string(), "1".to_string())]));
    }

    /// Round trip: a router read from TOML, written to JSON and read back keeps the same
    /// raw route list and still resolves the same routes.
    #[test]
    fn test_router_serialize() {
        let config = r#"
 "/api/v1/users" = [
 { function = "get_user", method = "GET" },
 { function = "create_user", method = "POST" }
 ]
 "/api/v1/all_methods" = "all_methods"
 "#;
        let router: FunctionRouter = toml::from_str(config).unwrap();

        let json = serde_json::to_value(&router).unwrap();

        let new_router: FunctionRouter = serde_json::from_value(json).unwrap();
        assert_eq!(new_router.raw, router.raw);

        assert_eq!(
            new_router.inner.at("/api/v1/users").unwrap().value,
            &FunctionRoutes::Multiple(HashMap::from([
                ("GET".to_string(), "get_user".to_string()),
                ("POST".to_string(), "create_user".to_string()),
            ]))
        );

        assert_eq!(
            new_router.inner.at("/api/v1/all_methods").unwrap().value,
            &FunctionRoutes::Single("all_methods".to_string())
        );
    }

    /// Round trip of `Watch` through JSON: first checks the flat keys produced by the
    /// hand-written `Serialize` impl, then checks that deserializing that JSON restores
    /// the same values.
    #[test]
    fn test_watch_serialization() {
        let watch = Watch {
            invoke_address: "127.0.0.1".to_string(),
            invoke_port: 9000,
            env_options: EnvOptions {
                env_file: Some(PathBuf::from("/tmp/env")),
                env_var: Some(vec!["FOO=BAR".to_string()]),
            },
            tls_options: TlsOptions::new(
                Some(PathBuf::from("/tmp/cert.pem")),
                Some(PathBuf::from("/tmp/key.pem")),
                Some(PathBuf::from("/tmp/ca.pem")),
            ),
            cargo_opts: Run {
                common: CommonOptions {
                    quiet: false,
                    jobs: None,
                    keep_going: false,
                    profile: None,
                    features: vec!["feature1".to_string()],
                    all_features: false,
                    no_default_features: true,
                    target: vec!["x86_64-unknown-linux-gnu".to_string()],
                    target_dir: Some(PathBuf::from("/tmp/target")),
                    message_format: vec!["json".to_string()],
                    verbose: 1,
                    color: Some("auto".to_string()),
                    frozen: true,
                    locked: true,
                    offline: true,
                    config: vec!["config.toml".to_string()],
                    unstable_flags: vec!["flag1".to_string()],
                    timings: None,
                },
                manifest_path: None,
                release: false,
                ignore_rust_version: false,
                unit_graph: false,
                packages: vec![],
                bin: vec![],
                example: vec![],
                args: vec![],
            },
            ..Default::default()
        };

        // Every group is flattened, so all keys live at the top level of the object.
        let json = serde_json::to_value(&watch).unwrap();
        assert_eq!(json["invoke_address"], "127.0.0.1");
        assert_eq!(json["invoke_port"], 9000);
        assert_eq!(json["env_file"], "/tmp/env");
        assert_eq!(json["env_var"], json!(["FOO=BAR"]));
        assert_eq!(json["tls_cert"], "/tmp/cert.pem");
        assert_eq!(json["tls_key"], "/tmp/key.pem");
        assert_eq!(json["tls_ca"], "/tmp/ca.pem");
        assert_eq!(json["features"], json!(["feature1"]));
        assert_eq!(json["no_default_features"], true);
        assert_eq!(json["target"], json!(["x86_64-unknown-linux-gnu"]));
        assert_eq!(json["target_dir"], "/tmp/target");
        assert_eq!(json["message_format"], json!(["json"]));
        assert_eq!(json["verbose"], 1);
        assert_eq!(json["color"], "auto");
        assert_eq!(json["frozen"], true);
        assert_eq!(json["locked"], true);
        assert_eq!(json["offline"], true);
        assert_eq!(json["config"], json!(["config.toml"]));
        assert_eq!(json["unstable_flags"], json!(["flag1"]));
        // `timings` was `None`, so looking it up is expected to yield JSON null.
        assert_eq!(json["timings"], Value::Null);

        let deserialized: Watch = serde_json::from_value(json).unwrap();

        assert_eq!(deserialized.invoke_address, watch.invoke_address);
        assert_eq!(deserialized.invoke_port, watch.invoke_port);
        assert_eq!(
            deserialized.env_options.env_file,
            watch.env_options.env_file
        );
        assert_eq!(deserialized.env_options.env_var, watch.env_options.env_var);
        assert_eq!(
            deserialized.tls_options.tls_cert,
            watch.tls_options.tls_cert
        );
        assert_eq!(deserialized.tls_options.tls_key, watch.tls_options.tls_key);
        assert_eq!(deserialized.tls_options.tls_ca, watch.tls_options.tls_ca);
        assert_eq!(
            deserialized.cargo_opts.common.features,
            watch.cargo_opts.common.features
        );
        assert_eq!(
            deserialized.cargo_opts.common.no_default_features,
            watch.cargo_opts.common.no_default_features
        );
        assert_eq!(
            deserialized.cargo_opts.common.target,
            watch.cargo_opts.common.target
        );
        assert_eq!(
            deserialized.cargo_opts.common.target_dir,
            watch.cargo_opts.common.target_dir
        );
        assert_eq!(
            deserialized.cargo_opts.common.message_format,
            watch.cargo_opts.common.message_format
        );
        assert_eq!(
            deserialized.cargo_opts.common.verbose,
            watch.cargo_opts.common.verbose
        );
        assert_eq!(
            deserialized.cargo_opts.common.color,
            watch.cargo_opts.common.color
        );
        assert_eq!(
            deserialized.cargo_opts.common.frozen,
            watch.cargo_opts.common.frozen
        );
        assert_eq!(
            deserialized.cargo_opts.common.locked,
            watch.cargo_opts.common.locked
        );
        assert_eq!(
            deserialized.cargo_opts.common.offline,
            watch.cargo_opts.common.offline
        );
        assert_eq!(
            deserialized.cargo_opts.common.config,
            watch.cargo_opts.common.config
        );
        assert_eq!(
            deserialized.cargo_opts.common.unstable_flags,
            watch.cargo_opts.common.unstable_flags
        );
        assert_eq!(
            deserialized.cargo_opts.common.timings,
            watch.cargo_opts.common.timings
        );
    }
}