fakecloud_core/
protocol.rs1use bytes::Bytes;
2use http::HeaderMap;
3use std::collections::HashMap;
4
/// Wire protocol used by an incoming AWS API request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AwsProtocol {
    /// `Action=...` parameters carried in the query string or a form-encoded body.
    Query,
    /// Operation named by the `X-Amz-Target` header.
    Json,
    /// A service listed in `REST_SERVICES` (currently S3); no action name is extracted.
    Rest,
}
18
/// Services classified as [`AwsProtocol::Rest`] purely from request signing (the SigV4
/// `Authorization` header or a presigned `X-Amz-Credential` query parameter).
const REST_SERVICES: &[&str] = &["s3"];
21
22#[derive(Debug)]
24pub struct DetectedRequest {
25 pub service: String,
26 pub action: String,
27 pub protocol: AwsProtocol,
28}
29
30pub fn detect_service(
32 headers: &HeaderMap,
33 query_params: &HashMap<String, String>,
34 body: &Bytes,
35) -> Option<DetectedRequest> {
36 if let Some(target) = headers.get("x-amz-target").and_then(|v| v.to_str().ok()) {
38 return parse_amz_target(target);
39 }
40
41 if let Some(action) = query_params.get("Action") {
43 let service =
44 extract_service_from_auth(headers).or_else(|| infer_service_from_action(action));
45 if let Some(service) = service {
46 return Some(DetectedRequest {
47 service,
48 action: action.clone(),
49 protocol: AwsProtocol::Query,
50 });
51 }
52 }
53
54 {
56 let form_params = decode_form_urlencoded(body);
57
58 if let Some(action) = form_params.get("Action") {
59 let service =
60 extract_service_from_auth(headers).or_else(|| infer_service_from_action(action));
61 if let Some(service) = service {
62 return Some(DetectedRequest {
63 service,
64 action: action.clone(),
65 protocol: AwsProtocol::Query,
66 });
67 }
68 }
69 }
70
71 if let Some(service) = extract_service_from_auth(headers) {
73 if REST_SERVICES.contains(&service.as_str()) {
74 return Some(DetectedRequest {
75 service,
76 action: String::new(), protocol: AwsProtocol::Rest,
78 });
79 }
80 }
81
82 if let Some(credential) = query_params.get("X-Amz-Credential") {
84 let parts: Vec<&str> = credential.split('/').collect();
86 if parts.len() >= 4 {
87 let service = parts[3].to_string();
88 if REST_SERVICES.contains(&service.as_str()) {
89 return Some(DetectedRequest {
90 service,
91 action: String::new(),
92 protocol: AwsProtocol::Rest,
93 });
94 }
95 }
96 }
97
98 if query_params.contains_key("AWSAccessKeyId") && query_params.contains_key("Signature") {
100 return Some(DetectedRequest {
102 service: "s3".to_string(),
103 action: String::new(),
104 protocol: AwsProtocol::Rest,
105 });
106 }
107
108 None
109}
110
111fn parse_amz_target(target: &str) -> Option<DetectedRequest> {
114 let (prefix, action) = target.rsplit_once('.')?;
115
116 let service = match prefix {
117 "AWSEvents" => "events",
118 "AmazonSSM" => "ssm",
119 "AmazonSQS" => "sqs",
120 "AmazonSNS" => "sns",
121 "DynamoDB_20120810" => "dynamodb",
122 "Logs_20140328" => "logs",
123 s if s.starts_with("secretsmanager") => "secretsmanager",
124 s if s.starts_with("TrentService") => "kms",
125 _ => return None,
126 };
127
128 Some(DetectedRequest {
129 service: service.to_string(),
130 action: action.to_string(),
131 protocol: AwsProtocol::Json,
132 })
133}
134
/// Guesses the service that owns a Query-protocol `action` name.
///
/// `detect_service` uses this as a fallback when no service can be read from the SigV4
/// `Authorization` header. Only a fixed set of STS and IAM action names is recognised, and
/// matching is exact and case-sensitive; anything else yields `None`.
fn infer_service_from_action(action: &str) -> Option<String> {
    const STS_ACTIONS: &[&str] = &[
        "AssumeRole",
        "AssumeRoleWithSAML",
        "AssumeRoleWithWebIdentity",
        "GetCallerIdentity",
        "GetSessionToken",
        "GetFederationToken",
        "GetAccessKeyInfo",
        "DecodeAuthorizationMessage",
    ];
    const IAM_ACTIONS: &[&str] = &[
        "CreateUser",
        "DeleteUser",
        "GetUser",
        "ListUsers",
        "CreateRole",
        "DeleteRole",
        "GetRole",
        "ListRoles",
        "CreatePolicy",
        "DeletePolicy",
        "GetPolicy",
        "ListPolicies",
        "AttachRolePolicy",
        "DetachRolePolicy",
        "CreateAccessKey",
        "DeleteAccessKey",
        "ListAccessKeys",
        "ListRolePolicies",
    ];

    let service = if STS_ACTIONS.contains(&action) {
        "sts"
    } else if IAM_ACTIONS.contains(&action) {
        "iam"
    } else {
        return None;
    };
    Some(service.to_owned())
}
155
156fn extract_service_from_auth(headers: &HeaderMap) -> Option<String> {
158 let auth = headers.get("authorization")?.to_str().ok()?;
159 let info = fakecloud_aws::sigv4::parse_sigv4(auth)?;
160 Some(info.service)
161}
162
163pub fn parse_query_body(body: &Bytes) -> HashMap<String, String> {
165 decode_form_urlencoded(body)
166}
167
/// Parses an `application/x-www-form-urlencoded` payload into a key → value map.
///
/// - Pairs are separated by `&`.
/// - A pair without `=` maps its key to an empty value.
/// - Empty pairs (e.g. from `a=1&&b=2`) are skipped.
/// - Keys and values are decoded with `url_decode`.
/// - When a key repeats, the last occurrence wins.
/// - Input that is not valid UTF-8 is treated as containing no parameters.
fn decode_form_urlencoded(input: &[u8]) -> HashMap<String, String> {
    let text = std::str::from_utf8(input).unwrap_or("");
    text.split('&')
        .filter(|pair| !pair.is_empty())
        .map(|pair| {
            let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
            (url_decode(key), url_decode(value))
        })
        .collect()
}

/// Percent-decodes one form-urlencoded component.
///
/// `+` decodes to a space, and `%XY` (two hex digits, either case) decodes to the byte `0xXY`.
/// Decoding happens at the byte level and the result is then read as UTF-8, so multi-byte
/// escapes such as `%C3%A9` produce `é` and raw non-ASCII input passes through unchanged.
/// Decoded byte sequences that are not valid UTF-8 become U+FFFD.
///
/// A `%` not followed by two hex digits is kept literally, and decoding resumes right after
/// it. This follows the WHATWG URL Standard's percent-decode algorithm.
fn url_decode(input: &str) -> String {
    let bytes = input.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'+' => {
                decoded.push(b' ');
                i += 1;
            }
            b'%' => {
                let hex_digit = |offset: usize| bytes.get(i + offset).copied().and_then(from_hex);
                match (hex_digit(1), hex_digit(2)) {
                    (Some(high), Some(low)) => {
                        decoded.push((high << 4) | low);
                        i += 3;
                    }
                    // Malformed escape: emit the `%` itself and leave what follows undecoded.
                    _ => {
                        decoded.push(b'%');
                        i += 1;
                    }
                }
            }
            other => {
                decoded.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8(decoded)
        .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned())
}

/// Returns the value of an ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None` for any other byte.
fn from_hex(b: u8) -> Option<u8> {
    match b {
        b'0'..=b'9' => Some(b - b'0'),
        b'a'..=b'f' => Some(b - b'a' + 10),
        b'A'..=b'F' => Some(b - b'A' + 10),
        _ => None,
    }
}
211
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_amz_target_events() {
        let result = parse_amz_target("AWSEvents.PutEvents").unwrap();
        assert_eq!(result.service, "events");
        assert_eq!(result.action, "PutEvents");
        assert_eq!(result.protocol, AwsProtocol::Json);
    }

    #[test]
    fn parse_amz_target_ssm() {
        let result = parse_amz_target("AmazonSSM.GetParameter").unwrap();
        assert_eq!(result.service, "ssm");
        assert_eq!(result.action, "GetParameter");
    }

    #[test]
    fn parse_amz_target_prefix_families() {
        let secrets = parse_amz_target("secretsmanager.GetSecretValue").unwrap();
        assert_eq!(secrets.service, "secretsmanager");
        assert_eq!(secrets.action, "GetSecretValue");
        assert_eq!(parse_amz_target("TrentService.Encrypt").unwrap().service, "kms");
        assert_eq!(parse_amz_target("DynamoDB_20120810.PutItem").unwrap().service, "dynamodb");
    }

    #[test]
    fn parse_amz_target_rejects_unknown_or_malformed() {
        assert!(parse_amz_target("UnknownService.DoThing").is_none());
        assert!(parse_amz_target("NoDotHere").is_none());
    }

    #[test]
    fn infer_service_from_action_known_and_unknown() {
        assert_eq!(infer_service_from_action("AssumeRole").as_deref(), Some("sts"));
        assert_eq!(infer_service_from_action("GetRole").as_deref(), Some("iam"));
        assert_eq!(infer_service_from_action("PutObject"), None);
    }

    #[test]
    fn parse_query_body_basic() {
        let body = Bytes::from(
            "Action=SendMessage&QueueUrl=http%3A%2F%2Flocalhost%3A4566%2Fqueue&MessageBody=hello",
        );
        let params = parse_query_body(&body);
        assert_eq!(params.get("Action").unwrap(), "SendMessage");
        assert_eq!(params.get("MessageBody").unwrap(), "hello");
    }

    #[test]
    fn detect_service_from_amz_target_header() {
        let mut headers = HeaderMap::new();
        headers.insert(
            "x-amz-target",
            http::HeaderValue::from_static("AmazonSQS.SendMessage"),
        );
        let detected = detect_service(&headers, &HashMap::new(), &Bytes::new()).unwrap();
        assert_eq!(detected.service, "sqs");
        assert_eq!(detected.action, "SendMessage");
        assert_eq!(detected.protocol, AwsProtocol::Json);
    }

    #[test]
    fn detect_service_query_action_uses_action_table_without_auth() {
        let mut query = HashMap::new();
        query.insert("Action".to_string(), "GetCallerIdentity".to_string());
        let detected = detect_service(&HeaderMap::new(), &query, &Bytes::new()).unwrap();
        assert_eq!(detected.service, "sts");
        assert_eq!(detected.action, "GetCallerIdentity");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_service_form_body_action() {
        let body = Bytes::from("Action=ListUsers&MaxItems=10");
        let detected = detect_service(&HeaderMap::new(), &HashMap::new(), &body).unwrap();
        assert_eq!(detected.service, "iam");
        assert_eq!(detected.action, "ListUsers");
        assert_eq!(detected.protocol, AwsProtocol::Query);
    }

    #[test]
    fn detect_service_presigned_sigv4_s3() {
        let mut query = HashMap::new();
        query.insert(
            "X-Amz-Credential".to_string(),
            "AKIDEXAMPLE/20240101/us-east-1/s3/aws4_request".to_string(),
        );
        let detected = detect_service(&HeaderMap::new(), &query, &Bytes::new()).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.action, "");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_sigv2_presigned_is_s3() {
        let mut query = HashMap::new();
        query.insert("AWSAccessKeyId".to_string(), "AKIDEXAMPLE".to_string());
        query.insert("Signature".to_string(), "abc".to_string());
        let detected = detect_service(&HeaderMap::new(), &query, &Bytes::new()).unwrap();
        assert_eq!(detected.service, "s3");
        assert_eq!(detected.protocol, AwsProtocol::Rest);
    }

    #[test]
    fn detect_service_unrecognised_request() {
        assert!(detect_service(&HeaderMap::new(), &HashMap::new(), &Bytes::new()).is_none());
    }
}