fakecloud_core/
protocol.rs1use bytes::Bytes;
2use http::HeaderMap;
3use std::collections::HashMap;
4
/// Wire protocol an incoming AWS API request was recognised as using.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AwsProtocol {
    /// Form-style parameters with an `Action` name, carried in the query
    /// string or in a form-encoded body.
    Query,
    /// JSON request whose operation is named by the `X-Amz-Target` header.
    Json,
    /// REST-style request; detection records only the target service.
    Rest,
}
18
/// Services routed with the REST protocol when a request's SigV4 signature
/// (the `Authorization` header or a presigned `X-Amz-Credential` parameter)
/// names them.
const REST_SERVICES: &[&str] = &["s3", "lambda"];
22
23#[derive(Debug)]
25pub struct DetectedRequest {
26 pub service: String,
27 pub action: String,
28 pub protocol: AwsProtocol,
29}
30
31pub fn detect_service(
33 headers: &HeaderMap,
34 query_params: &HashMap<String, String>,
35 body: &Bytes,
36) -> Option<DetectedRequest> {
37 if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
39 return parse_amz_target(target);
40 }
41
42 if let Some(action) = query_params.get("Action") {
44 let service =
45 extract_service_from_auth(headers).or_else(|| infer_service_from_action(action));
46 if let Some(service) = service {
47 return Some(DetectedRequest {
48 service,
49 action: action.clone(),
50 protocol: AwsProtocol::Query,
51 });
52 }
53 }
54
55 {
57 let form_params = decode_form_urlencoded(body);
58
59 if let Some(action) = form_params.get("Action") {
60 let service =
61 extract_service_from_auth(headers).or_else(|| infer_service_from_action(action));
62 if let Some(service) = service {
63 return Some(DetectedRequest {
64 service,
65 action: action.clone(),
66 protocol: AwsProtocol::Query,
67 });
68 }
69 }
70 }
71
72 if let Some(service) = extract_service_from_auth(headers) {
74 if REST_SERVICES.contains(&service.as_str()) {
75 return Some(DetectedRequest {
76 service,
77 action: String::new(), protocol: AwsProtocol::Rest,
79 });
80 }
81 }
82
83 if let Some(credential) = query_params.get("X-Amz-Credential") {
85 let parts: Vec<&str> = credential.split('/').collect();
87 if parts.len() >= 4 {
88 let service = parts[3].to_string();
89 if REST_SERVICES.contains(&service.as_str()) {
90 return Some(DetectedRequest {
91 service,
92 action: String::new(),
93 protocol: AwsProtocol::Rest,
94 });
95 }
96 }
97 }
98
99 if query_params.contains_key("AWSAccessKeyId")
103 && query_params.contains_key("Signature")
104 && query_params.contains_key("Expires")
105 {
106 return Some(DetectedRequest {
107 service: "s3".to_string(),
108 action: String::new(),
109 protocol: AwsProtocol::Rest,
110 });
111 }
112
113 None
114}
115
116fn parse_amz_target(target: &str) -> Option<DetectedRequest> {
119 let (prefix, action) = target.rsplit_once('.')?;
120
121 let service = match prefix {
122 "AWSEvents" => "events",
123 "AmazonSSM" => "ssm",
124 "AmazonSQS" => "sqs",
125 "AmazonSNS" => "sns",
126 "DynamoDB_20120810" => "dynamodb",
127 "Logs_20140328" => "logs",
128 s if s.starts_with("secretsmanager") => "secretsmanager",
129 s if s.starts_with("TrentService") => "kms",
130 _ => return None,
131 };
132
133 Some(DetectedRequest {
134 service: service.to_string(),
135 action: action.to_string(),
136 protocol: AwsProtocol::Json,
137 })
138}
139
/// Guesses which service owns a Query-protocol `Action` when the service could
/// not be read from the request's `Authorization` header.
///
/// Only the STS and IAM action names listed below are recognised (matching is
/// exact and case-sensitive); anything else yields `None`.
fn infer_service_from_action(action: &str) -> Option<String> {
    const STS_ACTIONS: &[&str] = &[
        "AssumeRole",
        "AssumeRoleWithSAML",
        "AssumeRoleWithWebIdentity",
        "GetCallerIdentity",
        "GetSessionToken",
        "GetFederationToken",
        "GetAccessKeyInfo",
        "DecodeAuthorizationMessage",
    ];
    const IAM_ACTIONS: &[&str] = &[
        "CreateUser",
        "DeleteUser",
        "GetUser",
        "ListUsers",
        "CreateRole",
        "DeleteRole",
        "GetRole",
        "ListRoles",
        "CreatePolicy",
        "DeletePolicy",
        "GetPolicy",
        "ListPolicies",
        "AttachRolePolicy",
        "DetachRolePolicy",
        "CreateAccessKey",
        "DeleteAccessKey",
        "ListAccessKeys",
        "ListRolePolicies",
    ];

    [("sts", STS_ACTIONS), ("iam", IAM_ACTIONS)]
        .iter()
        .find(|(_, actions)| actions.contains(&action))
        .map(|&(service, _)| service.to_string())
}
160
161fn extract_service_from_auth(headers: &HeaderMap) -> Option<String> {
163 let auth = headers.get("authorization")?.to_str().ok()?;
164 let info = fakecloud_aws::sigv4::parse_sigv4(auth)?;
165 Some(info.service)
166}
167
168pub fn parse_query_body(body: &Bytes) -> HashMap<String, String> {
170 decode_form_urlencoded(body)
171}
172
173fn decode_form_urlencoded(input: &[u8]) -> HashMap<String, String> {
174 let s = std::str::from_utf8(input).unwrap_or("");
175 let mut result = HashMap::new();
176 for pair in s.split('&') {
177 if pair.is_empty() {
178 continue;
179 }
180 let (key, value) = match pair.find('=') {
181 Some(pos) => (&pair[..pos], &pair[pos + 1..]),
182 None => (pair, ""),
183 };
184 result.insert(url_decode(key), url_decode(value));
185 }
186 result
187}
188
189fn url_decode(input: &str) -> String {
190 let mut result = String::with_capacity(input.len());
191 let mut bytes = input.bytes();
192 while let Some(b) = bytes.next() {
193 match b {
194 b'+' => result.push(' '),
195 b'%' => {
196 let high = bytes.next().and_then(from_hex);
197 let low = bytes.next().and_then(from_hex);
198 if let (Some(h), Some(l)) = (high, low) {
199 result.push((h << 4 | l) as char);
200 }
201 }
202 _ => result.push(b as char),
203 }
204 }
205 result
206}
207
/// Converts one ASCII hexadecimal digit (`0-9`, `a-f`, `A-F`) to its value
/// 0–15, returning `None` for any other byte.
fn from_hex(b: u8) -> Option<u8> {
    // `to_digit` only accepts ASCII digits and letters, so bytes >= 0x80
    // (widened to Latin-1 chars) are rejected like any other non-hex byte.
    let digit = char::from(b).to_digit(16)?;
    // `digit` is at most 15, so narrowing to u8 is lossless.
    Some(digit as u8)
}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `detect_service` with no headers and the given query parameters and body.
    fn detect_unsigned(query: &[(&str, &str)], body: &'static str) -> Option<DetectedRequest> {
        let query_params: HashMap<String, String> = query
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect();
        detect_service(&HeaderMap::new(), &query_params, &Bytes::from(body))
    }

    #[test]
    fn parse_amz_target_events() {
        let result = parse_amz_target("AWSEvents.PutEvents").unwrap();
        assert_eq!(result.service, "events");
        assert_eq!(result.action, "PutEvents");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_ssm() {
        let result = parse_amz_target("AmazonSSM.GetParameter").unwrap();
        assert_eq!(result.service, "ssm");
        assert_eq!(result.action, "GetParameter");
    }

    #[test]
    fn parse_amz_target_versioned_and_prefixed_services() {
        let cases = [
            ("DynamoDB_20120810.GetItem", "dynamodb", "GetItem"),
            ("Logs_20140328.PutLogEvents", "logs", "PutLogEvents"),
            ("secretsmanager.GetSecretValue", "secretsmanager", "GetSecretValue"),
            ("TrentService.Encrypt", "kms", "Encrypt"),
        ];
        for (target, service, action) in cases {
            let result = parse_amz_target(target).unwrap();
            assert_eq!(result.service, service);
            assert_eq!(result.action, action);
            assert_eq!(result.protocol, AwsProtocol::Json);
        }
    }

    #[test]
    fn parse_amz_target_rejects_unknown_or_malformed() {
        assert!(parse_amz_target("UnknownService.DoThing").is_none());
        assert!(parse_amz_target("AWSEvents").is_none());
    }

    #[test]
    fn parse_query_body_basic() {
        let body = Bytes::from(
            "Action=SendMessage&QueueUrl=http%3A%2F%2Flocalhost%3A4566%2Fqueue&MessageBody=hello",
        );
        let params = parse_query_body(&body);
        assert_eq!(params.get("Action").unwrap(), "SendMessage");
        assert_eq!(params.get("QueueUrl").unwrap(), "http://localhost:4566/queue");
        assert_eq!(params.get("MessageBody").unwrap(), "hello");
    }

    #[test]
    fn parse_query_body_plus_and_empty_values() {
        let params = parse_query_body(&Bytes::from("MessageBody=hello+world&Flag&Empty="));
        assert_eq!(params["MessageBody"], "hello world");
        assert_eq!(params["Flag"], "");
        assert_eq!(params["Empty"], "");
    }

    #[test]
    fn detect_service_json_from_amz_target() {
        let mut headers = HeaderMap::new();
        headers.insert("x-amz-target", "AmazonSQS.SendMessage".parse().unwrap());
        let detected = detect_service(&headers, &HashMap::new(), &Bytes::new()).unwrap();
        assert_eq!(detected.service, "sqs");
        assert_eq!(detected.action, "SendMessage");
        assert_eq!(detected.protocol, AwsProtocol::Json);
    }

    #[test]
    fn detect_service_query_from_query_string() {
        let detected = detect_unsigned(&[("Action", "ListUsers")], "").unwrap();
        assert_eq!(detected.service, "iam");
        assert_eq!(detected.action, "ListUsers");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_service_query_from_form_body() {
        let detected =
            detect_unsigned(&[], "Action=GetCallerIdentity&Version=2011-06-15").unwrap();
        assert_eq!(detected.service, "sts");
        assert_eq!(detected.action, "GetCallerIdentity");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_service_rest_from_presigned_urls() {
        let sigv4 = detect_unsigned(
            &[("X-Amz-Credential", "AKIDEXAMPLE/20240101/us-east-1/s3/aws4_request")],
            "",
        )
        .unwrap();
        assert_eq!(sigv4.service, "s3");
        assert_eq!(sigv4.action, "");
        assert_eq!(sigv4.protocol, AwsProtocol::Rest);

        let legacy = detect_unsigned(
            &[
                ("AWSAccessKeyId", "AKIDEXAMPLE"),
                ("Signature", "sig"),
                ("Expires", "1700000000"),
            ],
            "",
        )
        .unwrap();
        assert_eq!(legacy.service, "s3");
        assert_eq!(legacy.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_returns_none_when_unidentifiable() {
        assert!(detect_unsigned(&[], "").is_none());
        assert!(detect_unsigned(&[("Action", "NotARealAction")], "").is_none());
        assert!(detect_unsigned(
            &[("X-Amz-Credential", "AKIDEXAMPLE/20240101/us-east-1/sqs/aws4_request")],
            "",
        )
        .is_none());
    }
}