fakecloud_core/
protocol.rs1use bytes::Bytes;
2use http::HeaderMap;
3use std::collections::HashMap;
4
/// Wire protocol used to encode an incoming AWS request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AwsProtocol {
    /// Parameters carried in the query string or a form-urlencoded body, with the
    /// operation named by an `Action` key.
    Query,
    /// Operation named by the `X-Amz-Target` header.
    Json,
    /// REST-style request for a service listed in `REST_XML_SERVICES` (also used for
    /// SigV2 presigned URLs).
    Rest,
    /// REST-style request for a service listed in `REST_JSON_SERVICES`.
    RestJson,
}
21
/// SigV4 service names routed as [`AwsProtocol::Rest`] requests.
const REST_XML_SERVICES: &[&str] = &["s3"];

/// SigV4 service names routed as [`AwsProtocol::RestJson`] requests.
const REST_JSON_SERVICES: &[&str] = &["lambda", "ses", "apigateway"];
27
28#[derive(Debug)]
30pub struct DetectedRequest {
31 pub service: String,
32 pub action: String,
33 pub protocol: AwsProtocol,
34}
35
36pub fn detect_service(
38 headers: &HeaderMap,
39 query_params: &HashMap<String, String>,
40 body: &Bytes,
41) -> Option<DetectedRequest> {
42 if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
44 return parse_amz_target(target);
45 }
46
47 if let Some(action) = query_params.get("Action") {
49 let service =
50 extract_service_from_auth(headers).or_else(|| infer_service_from_action(action));
51 if let Some(service) = service {
52 return Some(DetectedRequest {
53 service,
54 action: action.clone(),
55 protocol: AwsProtocol::Query,
56 });
57 }
58 }
59
60 {
62 let form_params = decode_form_urlencoded(body);
63
64 if let Some(action) = form_params.get("Action") {
65 let service =
66 extract_service_from_auth(headers).or_else(|| infer_service_from_action(action));
67 if let Some(service) = service {
68 return Some(DetectedRequest {
69 service,
70 action: action.clone(),
71 protocol: AwsProtocol::Query,
72 });
73 }
74 }
75 }
76
77 if let Some(service) = extract_service_from_auth(headers) {
79 if let Some(protocol) = rest_protocol_for(&service) {
80 return Some(DetectedRequest {
81 service,
82 action: String::new(), protocol,
84 });
85 }
86 }
87
88 if let Some(credential) = query_params.get("X-Amz-Credential") {
90 let parts: Vec<&str> = credential.split('/').collect();
92 if parts.len() >= 4 {
93 let service = parts[3].to_string();
94 if let Some(protocol) = rest_protocol_for(&service) {
95 return Some(DetectedRequest {
96 service,
97 action: String::new(),
98 protocol,
99 });
100 }
101 }
102 }
103
104 if query_params.contains_key("AWSAccessKeyId")
108 && query_params.contains_key("Signature")
109 && query_params.contains_key("Expires")
110 {
111 return Some(DetectedRequest {
112 service: "s3".to_string(),
113 action: String::new(),
114 protocol: AwsProtocol::Rest,
115 });
116 }
117
118 None
119}
120
121fn parse_amz_target(target: &str) -> Option<DetectedRequest> {
124 let (prefix, action) = target.rsplit_once('.')?;
125
126 let service = match prefix {
127 "AWSEvents" => "events",
128 "AmazonSSM" => "ssm",
129 "AmazonSQS" => "sqs",
130 "AmazonSNS" => "sns",
131 "DynamoDB_20120810" => "dynamodb",
132 "Logs_20140328" => "logs",
133 s if s.starts_with("secretsmanager") => "secretsmanager",
134 s if s.starts_with("TrentService") => "kms",
135 s if s.starts_with("AWSCognitoIdentityProviderService") => "cognito-idp",
136 s if s.starts_with("Kinesis_20131202") => "kinesis",
137 s if s.starts_with("AWSStepFunctions") => "states",
138 _ => return None,
139 };
140
141 Some(DetectedRequest {
142 service: service.to_string(),
143 action: action.to_string(),
144 protocol: AwsProtocol::Json,
145 })
146}
147
148fn rest_protocol_for(service: &str) -> Option<AwsProtocol> {
150 if REST_XML_SERVICES.contains(&service) {
151 Some(AwsProtocol::Rest)
152 } else if REST_JSON_SERVICES.contains(&service) {
153 Some(AwsProtocol::RestJson)
154 } else {
155 None
156 }
157}
158
/// Guesses the Query-protocol service from an `Action` name alone.
///
/// Used when the request carries no parseable SigV4 `Authorization` header. Only the
/// STS, IAM and SES actions listed below are recognised; anything else yields `None`.
/// Matching is exact and case-sensitive.
fn infer_service_from_action(action: &str) -> Option<String> {
    const STS_ACTIONS: &[&str] = &[
        "AssumeRole",
        "AssumeRoleWithSAML",
        "AssumeRoleWithWebIdentity",
        "GetCallerIdentity",
        "GetSessionToken",
        "GetFederationToken",
        "GetAccessKeyInfo",
        "DecodeAuthorizationMessage",
    ];
    const IAM_ACTIONS: &[&str] = &[
        "CreateUser",
        "DeleteUser",
        "GetUser",
        "ListUsers",
        "CreateRole",
        "DeleteRole",
        "GetRole",
        "ListRoles",
        "CreatePolicy",
        "DeletePolicy",
        "GetPolicy",
        "ListPolicies",
        "AttachRolePolicy",
        "DetachRolePolicy",
        "CreateAccessKey",
        "DeleteAccessKey",
        "ListAccessKeys",
        "ListRolePolicies",
    ];
    const SES_ACTIONS: &[&str] = &[
        "VerifyEmailIdentity",
        "VerifyDomainIdentity",
        "VerifyDomainDkim",
        "ListIdentities",
        "GetIdentityVerificationAttributes",
        "GetIdentityDkimAttributes",
        "DeleteIdentity",
        "SetIdentityDkimEnabled",
        "SetIdentityNotificationTopic",
        "SetIdentityFeedbackForwardingEnabled",
        "GetIdentityNotificationAttributes",
        "GetIdentityMailFromDomainAttributes",
        "SetIdentityMailFromDomain",
        "SendEmail",
        "SendRawEmail",
        "SendTemplatedEmail",
        "SendBulkTemplatedEmail",
        "CreateTemplate",
        "GetTemplate",
        "ListTemplates",
        "DeleteTemplate",
        "UpdateTemplate",
        "CreateConfigurationSet",
        "DeleteConfigurationSet",
        "DescribeConfigurationSet",
        "ListConfigurationSets",
        "CreateConfigurationSetEventDestination",
        "UpdateConfigurationSetEventDestination",
        "DeleteConfigurationSetEventDestination",
        "GetSendQuota",
        "GetSendStatistics",
        "GetAccountSendingEnabled",
        "CreateReceiptRuleSet",
        "DeleteReceiptRuleSet",
        "DescribeReceiptRuleSet",
        "ListReceiptRuleSets",
        "CloneReceiptRuleSet",
        "SetActiveReceiptRuleSet",
        "ReorderReceiptRuleSet",
        "CreateReceiptRule",
        "DeleteReceiptRule",
        "DescribeReceiptRule",
        "UpdateReceiptRule",
        "CreateReceiptFilter",
        "DeleteReceiptFilter",
        "ListReceiptFilters",
    ];
    // The action lists are disjoint, so search order does not affect the result.
    const BY_SERVICE: &[(&str, &[&str])] = &[
        ("sts", STS_ACTIONS),
        ("iam", IAM_ACTIONS),
        ("ses", SES_ACTIONS),
    ];

    BY_SERVICE
        .iter()
        .find(|(_, actions)| actions.contains(&action))
        .map(|(service, _)| (*service).to_owned())
}
226
227fn extract_service_from_auth(headers: &HeaderMap) -> Option<String> {
229 let auth = headers.get("authorization")?.to_str().ok()?;
230 let info = fakecloud_aws::sigv4::parse_sigv4(auth)?;
231 Some(info.service)
232}
233
234pub fn parse_query_body(body: &Bytes) -> HashMap<String, String> {
236 decode_form_urlencoded(body)
237}
238
/// Splits a form-urlencoded byte string on `&` and decodes each `key=value` pair.
///
/// Empty segments are skipped, a segment without `=` maps to an empty value, and when
/// a key repeats the last occurrence wins. Bytes that are not valid UTF-8 become
/// U+FFFD instead of causing the whole input to be discarded.
fn decode_form_urlencoded(input: &[u8]) -> HashMap<String, String> {
    String::from_utf8_lossy(input)
        .split('&')
        .filter(|pair| !pair.is_empty())
        .map(|pair| {
            let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
            (url_decode(key), url_decode(value))
        })
        .collect()
}

/// Percent-decodes one form component, treating `+` as a space.
///
/// Decoding works on bytes and the result is interpreted as UTF-8. That way multi-byte
/// characters, whether escaped (`%C3%A9`) or literal (`é`), come out as one `char`
/// rather than one Latin-1 `char` per byte.
///
/// A `%` that is not followed by two hex digits is kept literally, as in the WHATWG
/// percent-decode algorithm. Decoded bytes that are not valid UTF-8 become U+FFFD.
fn url_decode(input: &str) -> String {
    let bytes = input.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while let Some(&byte) = bytes.get(i) {
        match byte {
            b'+' => {
                decoded.push(b' ');
                i += 1;
            }
            b'%' => {
                let high = bytes.get(i + 1).copied().and_then(from_hex);
                let low = bytes.get(i + 2).copied().and_then(from_hex);
                match (high, low) {
                    (Some(h), Some(l)) => {
                        decoded.push((h << 4) | l);
                        i += 3;
                    }
                    // Malformed escape: keep the '%' and re-scan the bytes after it.
                    _ => {
                        decoded.push(b'%');
                        i += 1;
                    }
                }
            }
            other => {
                decoded.push(other);
                i += 1;
            }
        }
    }
    // Avoid a second copy in the common, valid-UTF-8 case.
    String::from_utf8(decoded)
        .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned())
}

/// Value of an ASCII hexadecimal digit (either case), or `None` for any other byte.
fn from_hex(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
282
#[cfg(test)]
mod tests {
    use super::*;
    use http::HeaderValue;

    #[test]
    fn parse_amz_target_events() {
        let result = parse_amz_target("AWSEvents.PutEvents").unwrap();
        assert_eq!(result.service, "events");
        assert_eq!(result.action, "PutEvents");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_ssm() {
        let result = parse_amz_target("AmazonSSM.GetParameter").unwrap();
        assert_eq!(result.service, "ssm");
        assert_eq!(result.action, "GetParameter");
    }

    #[test]
    fn parse_amz_target_kinesis() {
        let result = parse_amz_target("Kinesis_20131202.ListStreams").unwrap();
        assert_eq!(result.service, "kinesis");
        assert_eq!(result.action, "ListStreams");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_rejects_unknown_or_malformed() {
        assert!(parse_amz_target("AmazonUnknown.DoThing").is_none());
        assert!(parse_amz_target("NoOperationSeparator").is_none());
    }

    #[test]
    fn parse_query_body_basic() {
        let body = Bytes::from(
            "Action=SendMessage&QueueUrl=http%3A%2F%2Flocalhost%3A4566%2Fqueue&MessageBody=hello",
        );
        let params = parse_query_body(&body);
        assert_eq!(params.get("Action").unwrap(), "SendMessage");
        assert_eq!(params.get("MessageBody").unwrap(), "hello");
    }

    #[test]
    fn parse_query_body_decodes_plus_and_bare_keys() {
        let params = parse_query_body(&Bytes::from("a=1&&flag&c=x+y%21"));
        assert_eq!(params.get("a").unwrap(), "1");
        assert_eq!(params.get("flag").unwrap(), "");
        assert_eq!(params.get("c").unwrap(), "x y!");
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn rest_protocol_for_known_services() {
        assert_eq!(rest_protocol_for("s3"), Some(AwsProtocol::Rest));
        assert_eq!(rest_protocol_for("lambda"), Some(AwsProtocol::RestJson));
        assert_eq!(rest_protocol_for("sqs"), None);
    }

    #[test]
    fn infer_service_from_action_fallback() {
        assert_eq!(infer_service_from_action("AssumeRole").as_deref(), Some("sts"));
        assert_eq!(infer_service_from_action("CreateRole").as_deref(), Some("iam"));
        assert_eq!(infer_service_from_action("SendEmail").as_deref(), Some("ses"));
        assert_eq!(infer_service_from_action("Unknown"), None);
    }

    #[test]
    fn detect_service_uses_amz_target() {
        let mut headers = HeaderMap::new();
        headers.insert("x-amz-target", HeaderValue::from_static("AmazonSQS.SendMessage"));
        let detected = detect_service(&headers, &HashMap::new(), &Bytes::new()).unwrap();
        assert_eq!(detected.service, "sqs");
        assert_eq!(detected.action, "SendMessage");
        assert_eq!(detected.protocol, AwsProtocol::Json);
    }

    #[test]
    fn detect_service_query_action_without_auth() {
        let mut query = HashMap::new();
        query.insert("Action".to_string(), "GetCallerIdentity".to_string());
        let detected = detect_service(&HeaderMap::new(), &query, &Bytes::new()).unwrap();
        assert_eq!(detected.service, "sts");
        assert_eq!(detected.action, "GetCallerIdentity");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_service_form_body_action() {
        let body = Bytes::from("Action=AssumeRole&RoleSessionName=test");
        let detected = detect_service(&HeaderMap::new(), &HashMap::new(), &body).unwrap();
        assert_eq!(detected.service, "sts");
        assert_eq!(detected.action, "AssumeRole");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_service_presigned_urls() {
        let mut v4 = HashMap::new();
        v4.insert(
            "X-Amz-Credential".to_string(),
            "AKIDEXAMPLE/20240101/us-east-1/s3/aws4_request".to_string(),
        );
        let detected = detect_service(&HeaderMap::new(), &v4, &Bytes::new()).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);

        let mut v2 = HashMap::new();
        for key in ["AWSAccessKeyId", "Signature", "Expires"] {
            v2.insert(key.to_string(), "x".to_string());
        }
        let detected = detect_service(&HeaderMap::new(), &v2, &Bytes::new()).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_returns_none_for_unidentifiable_request() {
        assert!(detect_service(&HeaderMap::new(), &HashMap::new(), &Bytes::new()).is_none());
    }
}