1use base64::{Engine as _, engine::general_purpose as b64};
2use cargo_lambda_metadata::DEFAULT_PACKAGE_FUNCTION;
3use cargo_lambda_remote::{
4 RemoteConfig,
5 aws_sdk_lambda::{Client as LambdaClient, primitives::Blob},
6 tls::TlsOptions,
7};
8use clap::{Args, ValueHint};
9use miette::{IntoDiagnostic, Result, WrapErr};
10use reqwest::{Client, StatusCode};
11use serde::Serialize;
12use serde_json::{from_str, to_string_pretty, value::Value};
13use std::{
14 convert::TryFrom,
15 fs::{File, create_dir_all, read_to_string},
16 io::copy,
17 net::IpAddr,
18 path::PathBuf,
19 str::{FromStr, from_utf8},
20};
21use strum_macros::{Display, EnumString};
22use tracing::debug;
23
24mod error;
25use error::*;
26
/// Base URL that `download_example` fetches `example-*.json` payloads from when no
/// other authority is given.
const EXAMPLES_URL: &str = "https://event-examples.cargo-lambda.info";

/// Header used by `invoke_local` to forward the client context to the local emulator.
const LAMBDA_RUNTIME_CLIENT_CONTEXT: &str = "lambda-runtime-client-context";
/// Header used by `invoke_local` to forward the JSON-serialized Cognito identity.
const LAMBDA_RUNTIME_COGNITO_IDENTITY: &str = "lambda-runtime-cognito-identity";
31
32#[derive(Args, Clone, Debug)]
33#[command(
34 name = "invoke",
35 after_help = "Full command documentation: https://www.cargo-lambda.info/commands/invoke.html"
36)]
37pub struct Invoke {
38 #[cfg_attr(
39 target_os = "windows",
40 arg(short = 'a', long, default_value = "127.0.0.1")
41 )]
42 #[cfg_attr(
43 not(target_os = "windows"),
44 arg(short = 'A', long, default_value = "::1")
45 )]
46 invoke_address: String,
48
49 #[arg(short = 'P', long, default_value = "9000")]
51 invoke_port: u16,
52
53 #[arg(short = 'F', long, value_hint = ValueHint::FilePath)]
55 data_file: Option<PathBuf>,
56
57 #[arg(short = 'A', long)]
59 data_ascii: Option<String>,
60
61 #[arg(short = 'E', long)]
63 data_example: Option<String>,
64
65 #[arg(short = 'R', long)]
67 remote: bool,
68
69 #[command(flatten)]
70 remote_config: RemoteConfig,
71
72 #[arg(long)]
74 client_context_ascii: Option<String>,
75
76 #[arg(long)]
78 client_context_file: Option<PathBuf>,
79
80 #[arg(short, long, default_value_t = OutputFormat::Text)]
82 output_format: OutputFormat,
83
84 #[command(flatten)]
85 cognito: Option<CognitoIdentity>,
86
87 #[arg(long, default_value_t = false)]
89 skip_cache: bool,
90
91 #[arg(default_value = DEFAULT_PACKAGE_FUNCTION)]
93 function_name: String,
94
95 #[command(flatten)]
96 tls_options: TlsOptions,
97}
98
/// How `Invoke::run` prints the function response.
///
/// The variant names double as the accepted CLI values: they are parsed via strum's
/// `EnumString` (case-insensitively) and rendered via `Display` for clap's default value.
#[derive(Clone, Debug, Display, EnumString)]
#[strum(ascii_case_insensitive)]
enum OutputFormat {
    /// Print the response body unchanged.
    Text,
    /// Parse the response body as JSON and pretty-print it.
    Json,
}
105
// Cognito identity that `Invoke::invoke_local` forwards to the local emulator in the
// `lambda-runtime-cognito-identity` header. It is serialized with serde_json using
// the `cognitoIdentityId` / `cognitoIdentityPoolId` keys.
//
// NOTE(review): clap 4 derive defaults an argument's id to the field name
// (`identity_id`, `identity_pool_id`), not to the kebab-case long flag. Confirm that
// `requires = "identity-pool-id"` / `requires = "identity-id"` resolve; clap's debug
// assertions reject references to unknown ids.
#[derive(Args, Clone, Debug, Serialize)]
pub struct CognitoIdentity {
    // `--identity-id`; meant to be given together with the pool id.
    #[arg(long, requires = "identity-pool-id")]
    #[serde(rename = "cognitoIdentityId")]
    pub identity_id: Option<String>,
    // `--identity-pool-id`; meant to be given together with the identity id.
    #[arg(long, requires = "identity-id")]
    #[serde(rename = "cognitoIdentityPoolId")]
    pub identity_pool_id: Option<String>,
}
117
118impl CognitoIdentity {
119 fn is_valid(&self) -> bool {
120 self.identity_id.is_some() && self.identity_pool_id.is_some()
121 }
122}
123
124impl Invoke {
125 #[tracing::instrument(skip(self), target = "cargo_lambda")]
126 pub async fn run(&self) -> Result<()> {
127 tracing::trace!(options = ?self, "invoking function");
128
129 let data = if let Some(file) = &self.data_file {
130 read_to_string(file)
131 .into_diagnostic()
132 .wrap_err("error reading data file")?
133 } else if let Some(data) = &self.data_ascii {
134 data.clone()
135 } else if let Some(example) = &self.data_example {
136 let name = example_name(example);
137
138 let cache = dirs::cache_dir()
139 .map(|p| p.join("cargo-lambda").join("invoke-fixtures").join(&name));
140
141 match cache {
142 Some(cache) if !self.skip_cache && cache.exists() => {
143 tracing::debug!(?cache, "using example from cache");
144 read_to_string(cache)
145 .into_diagnostic()
146 .wrap_err("error reading data file")?
147 }
148 _ if self.skip_cache => download_example(&name, None, None).await?,
149 _ => download_example(&name, cache, None).await?,
150 }
151 } else {
152 return Err(InvokeError::MissingPayload.into());
153 };
154
155 let text = if self.remote {
156 self.invoke_remote(&data).await?
157 } else {
158 self.invoke_local(&data).await?
159 };
160
161 let text = match &self.output_format {
162 OutputFormat::Text => text,
163 OutputFormat::Json => {
164 let obj: Value = from_str(&text)
165 .into_diagnostic()
166 .wrap_err("failed to serialize response into json")?;
167
168 to_string_pretty(&obj)
169 .into_diagnostic()
170 .wrap_err("failed to format json output")?
171 }
172 };
173
174 println!("{text}");
175
176 Ok(())
177 }
178
179 async fn invoke_remote(&self, data: &str) -> Result<String> {
180 if self.function_name == DEFAULT_PACKAGE_FUNCTION {
181 return Err(InvokeError::InvalidFunctionName.into());
182 }
183
184 let client_context = self.client_context(true)?;
185
186 let sdk_config = self.remote_config.sdk_config(None).await;
187 let client = LambdaClient::new(&sdk_config);
188
189 let resp = client
190 .invoke()
191 .function_name(&self.function_name)
192 .set_qualifier(self.remote_config.alias.clone())
193 .payload(Blob::new(data.as_bytes()))
194 .set_client_context(client_context)
195 .send()
196 .await
197 .into_diagnostic()
198 .wrap_err("failed to invoke remote function")?;
199
200 if let Some(payload) = resp.payload {
201 let blob = payload.into_inner();
202 let data = from_utf8(&blob)
203 .into_diagnostic()
204 .wrap_err("failed to read response payload")?;
205
206 if resp.function_error.is_some() {
207 let err = RemoteInvokeError::try_from(data)?;
208 Err(err.into())
209 } else {
210 Ok(data.into())
211 }
212 } else {
213 Ok("OK".into())
214 }
215 }
216
217 async fn invoke_local(&self, data: &str) -> Result<String> {
218 let host = parse_invoke_ip_address(&self.invoke_address)?;
219
220 let (protocol, client) = if self.tls_options.is_secure() {
221 let tls = self.tls_options.client_config()?;
222 let client = Client::builder()
223 .use_preconfigured_tls(tls)
224 .build()
225 .into_diagnostic()?;
226
227 ("https", client)
228 } else {
229 ("http", Client::new())
230 };
231
232 let url = format!(
233 "{}://{}:{}/2015-03-31/functions/{}/invocations",
234 protocol, &host, self.invoke_port, &self.function_name
235 );
236
237 let mut req = client.post(url).body(data.to_string());
238 if let Some(identity) = &self.cognito {
239 if identity.is_valid() {
240 let ser = serde_json::to_string(&identity)
241 .into_diagnostic()
242 .wrap_err("failed to serialize Cognito's identity information")?;
243 req = req.header(LAMBDA_RUNTIME_COGNITO_IDENTITY, ser);
244 }
245 }
246 if let Some(client_context) = self.client_context(false)? {
247 req = req.header(LAMBDA_RUNTIME_CLIENT_CONTEXT, client_context);
248 }
249
250 let resp = req
251 .send()
252 .await
253 .into_diagnostic()
254 .wrap_err("error sending request to the runtime emulator")?;
255 let success = resp.status() == StatusCode::OK;
256
257 let payload = resp
258 .text()
259 .await
260 .into_diagnostic()
261 .wrap_err("error reading response body")?;
262
263 if success {
264 Ok(payload)
265 } else {
266 debug!(error = ?payload, "error received from server");
267 let err = RemoteInvokeError::try_from(payload.as_str())?;
268 Err(err.into())
269 }
270 }
271
272 fn client_context(&self, encode: bool) -> Result<Option<String>> {
273 let mut data = if let Some(file) = &self.client_context_file {
274 read_to_string(file)
275 .into_diagnostic()
276 .wrap_err("error reading client context file")?
277 } else if let Some(data) = &self.client_context_ascii {
278 data.clone()
279 } else {
280 return Ok(None);
281 };
282
283 if encode {
284 data = b64::STANDARD.encode(data)
285 }
286
287 Ok(Some(data))
288 }
289}
290
/// Normalizes a user-supplied example identifier into the file name used by the
/// examples server and the local cache: `example-<name>.json`.
///
/// The `example-` prefix and the `.json` suffix are each added only when missing.
/// So `apigw-request`, `example-apigw-request` and `apigw-request.json` all map to
/// `example-apigw-request.json`.
fn example_name(example: &str) -> String {
    const PREFIX: &str = "example-";
    const SUFFIX: &str = ".json";

    let prefix = if example.starts_with(PREFIX) { "" } else { PREFIX };
    let suffix = if example.ends_with(SUFFIX) { "" } else { SUFFIX };

    format!("{prefix}{example}{suffix}")
}
302
303async fn download_example(
304 name: &str,
305 cache: Option<PathBuf>,
306 authority: Option<&str>,
307) -> Result<String> {
308 let authority = authority.unwrap_or(EXAMPLES_URL);
309 let target = format!("{authority}/{name}");
310
311 tracing::debug!(?target, "downloading remote example");
312 let response = reqwest::get(&target)
313 .await
314 .into_diagnostic()
315 .wrap_err("error dowloading example data")?;
316
317 if response.status() != StatusCode::OK {
318 Err(InvokeError::ExampleDownloadFailed(target, response).into())
319 } else {
320 let content = response
321 .text()
322 .await
323 .into_diagnostic()
324 .wrap_err("error reading example data")?;
325
326 if let Some(cache) = cache {
327 tracing::debug!(?cache, "storing example in cache");
328 create_dir_all(cache.parent().unwrap()).into_diagnostic()?;
329 let mut dest = File::create(cache).into_diagnostic()?;
330 copy(&mut content.as_bytes(), &mut dest).into_diagnostic()?;
331 }
332 Ok(content)
333 }
334}
335
336fn parse_invoke_ip_address(address: &str) -> Result<String> {
337 let invoke_address = IpAddr::from_str(address).map_err(|e| miette::miette!(e))?;
338
339 let invoke_address = match invoke_address {
340 IpAddr::V4(address) => address.to_string(),
341 IpAddr::V6(address) => format!("[{address}]"),
342 };
343
344 Ok(invoke_address)
345}
346
#[cfg(test)]
mod test {
    use httpmock::MockServer;

    use super::*;

    #[tokio::test]
    async fn test_download_example() {
        let server = MockServer::start_async().await;

        let mock = server.mock(|when, then| {
            when.path("/example-apigw-request.json");
            then.status(200)
                .header("Content-Type", "application/json")
                .body_from_file("../../tests/fixtures/events/example-apigw-request.json");
        });

        let data = download_example(
            "example-apigw-request.json",
            None,
            Some(&format!("http://{}", server.address())),
        )
        .await
        .expect("failed to download json");

        mock.assert();
        assert!(data.contains("\"path\": \"/hello/world\""));
    }

    #[tokio::test]
    async fn test_download_example_not_found() {
        let server = MockServer::start_async().await;

        let mock = server.mock(|when, then| {
            when.path("/example-missing.json");
            then.status(404);
        });

        let result = download_example(
            "example-missing.json",
            None,
            Some(&format!("http://{}", server.address())),
        )
        .await;

        mock.assert();
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_download_example_with_cache() {
        let server = MockServer::start_async().await;

        let mock = server.mock(|when, then| {
            when.path("/example-apigw-request.json");
            then.status(200)
                .header("Content-Type", "application/json")
                .body_from_file("../../tests/fixtures/events/example-apigw-request.json");
        });

        // Bind the guard for the whole test. A `TempDir` deletes its directory when
        // dropped, so using a temporary here would remove the directory before the
        // download. The cache file would then be recreated outside any managed temp dir.
        let tmp = tempfile::TempDir::new().unwrap();
        let cache = tmp
            .path()
            .join("cargo-lambda")
            .join("example-apigw-request.json");

        let data = download_example(
            "example-apigw-request.json",
            Some(cache.clone()),
            Some(&format!("http://{}", server.address())),
        )
        .await
        .unwrap();

        mock.assert();
        assert!(data.contains("\"path\": \"/hello/world\""));
        assert!(cache.exists());

        let content = read_to_string(&cache).unwrap();
        assert_eq!(content, data);
    }

    #[test]
    fn test_example_name() {
        assert_eq!(example_name("apigw-request"), "example-apigw-request.json");
        assert_eq!(
            example_name("apigw-request.json"),
            "example-apigw-request.json"
        );
        assert_eq!(
            example_name("example-apigw-request"),
            "example-apigw-request.json"
        );
        assert_eq!(
            example_name("example-apigw-request.json"),
            "example-apigw-request.json"
        );
    }

    #[test]
    fn test_parse_invoke_ip_address() {
        assert_eq!(parse_invoke_ip_address("127.0.0.1").unwrap(), "127.0.0.1");
        assert_eq!(parse_invoke_ip_address("::1").unwrap(), "[::1]");
        assert!(parse_invoke_ip_address("localhost").is_err());
    }
}