1use std::collections::HashMap;
2use serde_urlencoded;
3use crate::config::{ClientInfo, ConfigUpdate};
4use super::{AuthProvider, AuthResult, authenticator::GatewayAuthError};
5use hyper::http::request::Parts;
6use std::str::FromStr;
7use regex::Regex;
8use tracing::{event, Level};
9
10#[derive(Debug)]
11pub struct AppKeyAuthProvider {
12 app_key: HashMap<String, ClientInfo>,
13 app_id: HashMap<String, String>, }
15
16
17impl AuthProvider for AppKeyAuthProvider {
18 fn update_config(&mut self, update: crate::config::ConfigUpdate) {
19 match update {
20 ConfigUpdate::ClientUpdate(client) => {
21 let client_key = client.app_key.clone();
22 let app_id = client.client_id.clone();
23 if let Some(old_app_key) = self.app_id.insert(app_id, client_key.clone()) {
24 self.app_key.remove(&old_app_key);
25 }
26 self.app_key.insert(client_key, client);
27 },
28 ConfigUpdate::ClientRemove(cid) => {
29 if let Some(app_key) = self.app_id.remove(&cid) {
30 self.app_key.remove(&app_key);
31 }
32 },
33 _ => {},
34 }
35 }
36
37 fn identify_client(&self, mut head: Parts, service_id: &str) -> Result<(Parts, AuthResult), GatewayAuthError> {
38 let appkey = Self::get_app_key(&head)?;
39 let client = self.app_key.get(&appkey).ok_or(GatewayAuthError::InvalidToken)?;
40 let sla = client.services.get(service_id).ok_or(GatewayAuthError::InvalidSLA)?;
41
42 let url = head.uri.to_string();
44 let replaced = format!("/~{}/", appkey);
45 let url = url.replace(&replaced, "/");
46 head.uri = hyper::Uri::from_str(&url).unwrap();
47
48 let result = AuthResult {
49 client_id: client.client_id.clone(),
50 sla: sla.clone(),
51 };
52 return Ok((head, result));
53 }
54}
55
56
57impl AppKeyAuthProvider {
58
59 pub fn new() -> Self {
60 AppKeyAuthProvider {
61 app_key: HashMap::new(),
62 app_id: HashMap::new(),
63 }
64 }
65
66 fn get_app_key(head: &Parts) -> Result<String, GatewayAuthError> {
67 if let Some(token) = head.headers.get("X-APP-KEY") {
69 if let Ok(token_str) = token.to_str() {
70 return Ok(String::from(token_str));
71 }
72 }
73
74 if let Some(query) = head.uri.query() {
76 let query_pairs = serde_urlencoded::from_str::<Vec<(String, String)>>(query);
77 if let Ok(pairs) = query_pairs {
78 for (k, v) in pairs {
79 if k.eq("_app_key") {
80 return Ok(v);
81 }
82 }
83 } else {
84 event!(Level::DEBUG, "{:?}", query_pairs);
85 }
86 }
87
88 let pattern = Regex::new(r"^/.+?/~(.+?)/").unwrap();
90 if let Some(appkey_match) = pattern.captures(head.uri.path()) {
91 if let Some(am) = appkey_match.get(1) {
92 return Ok(String::from(am.as_str()))
93 }
94 }
95
96 Err(GatewayAuthError::TokenNotFound)
97 }
98}
99