1use crate::api::{ApiResponse, ApiState};
2use axum::{
3 extract::{Query, State},
4 response::{Html, Json},
5};
6use base64::Engine;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Serialize, Deserialize)]
12pub struct SamlSsoRequest {
13 pub idp_entity_id: String,
14 pub relay_state: Option<String>,
15 pub force_authn: Option<bool>,
16 pub is_passive: Option<bool>,
17}
18
19#[derive(Debug, Serialize, Deserialize)]
21pub struct SamlSsoResponse {
22 pub redirect_url: String,
23 pub saml_request: String,
24 pub relay_state: Option<String>,
25}
26
27#[derive(Debug, Serialize, Deserialize)]
29pub struct SamlAcsRequest {
30 #[serde(rename = "SAMLResponse")]
31 pub saml_response: String,
32 #[serde(rename = "RelayState")]
33 pub relay_state: Option<String>,
34 #[serde(rename = "SigAlg")]
35 pub sig_alg: Option<String>,
36 #[serde(rename = "Signature")]
37 pub signature: Option<String>,
38}
39
40#[derive(Debug, Serialize, Deserialize)]
42pub struct SamlMetadataResponse {
43 pub entity_id: String,
44 pub acs_url: String,
45 pub sls_url: Option<String>,
46 pub certificate: Option<String>,
47 pub name_id_format: String,
48}
49
50#[derive(Debug, Serialize, Deserialize)]
52pub struct SamlLogoutRequest {
53 pub name_id: String,
54 pub session_index: Option<String>,
55 pub idp_entity_id: String,
56}
57
58#[derive(Debug, Serialize, Deserialize)]
60pub struct SamlLogoutResponse {
61 pub redirect_url: String,
62 pub status: String,
63}
64
65pub async fn get_saml_metadata(State(_state): State<ApiState>) -> Html<String> {
67 let metadata_xml = r#"<?xml version="1.0" encoding="UTF-8"?>
68<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata"
69 entityID="https://auth.example.com">
70 <md:SPSSODescriptor AuthnRequestsSigned="true" WantAssertionsSigned="true"
71 protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
72 <md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress</md:NameIDFormat>
73 <md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST"
74 Location="https://auth.example.com/api/saml/acs"
75 index="0" />
76 <md:SingleLogoutService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect"
77 Location="https://auth.example.com/api/saml/slo" />
78 </md:SPSSODescriptor>
79</md:EntityDescriptor>"#;
80
81 Html(metadata_xml.to_string())
82}
83
84pub async fn initiate_saml_sso(
86 State(_state): State<ApiState>,
87 Json(request): Json<SamlSsoRequest>,
88) -> Json<ApiResponse<SamlSsoResponse>> {
89 let request_id = format!("saml_{}", uuid::Uuid::new_v4());
91 let issue_instant = chrono::Utc::now().to_rfc3339();
92
93 let saml_request = format!(
94 r#"<samlp:AuthnRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
95 xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
96 ID="{}"
97 Version="2.0"
98 IssueInstant="{}"
99 Destination="https://idp.example.com/sso"
100 {}
101 {}
102 AssertionConsumerServiceURL="https://auth.example.com/api/saml/acs">
103 <saml:Issuer>https://auth.example.com</saml:Issuer>
104 <samlp:NameIDPolicy Format="urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress"
105 AllowCreate="true" />
106</samlp:AuthnRequest>"#,
107 request_id,
108 issue_instant,
109 request
110 .force_authn
111 .map_or(String::new(), |fa| format!(r#"ForceAuthn="{}""#, fa)),
112 request
113 .is_passive
114 .map_or(String::new(), |ip| format!(r#"IsPassive="{}""#, ip))
115 );
116
117 let encoded_request = base64::engine::general_purpose::STANDARD.encode(&saml_request);
119
120 let mut redirect_url = format!(
122 "https://idp.example.com/sso?SAMLRequest={}",
123 urlencoding::encode(&encoded_request)
124 );
125
126 if let Some(relay_state) = &request.relay_state {
127 redirect_url.push_str(&format!("&RelayState={}", urlencoding::encode(relay_state)));
128 }
129
130 let _request_key = format!("saml_request:{}", request_id);
132 let _request_data = serde_json::json!({
133 "request_id": request_id,
134 "idp_entity_id": request.idp_entity_id,
135 "relay_state": request.relay_state,
136 "timestamp": chrono::Utc::now().timestamp()
137 });
138
139 Json(ApiResponse::success(SamlSsoResponse {
140 redirect_url,
141 saml_request: encoded_request,
142 relay_state: request.relay_state,
143 }))
144}
145
146pub async fn handle_saml_acs(
148 State(_state): State<ApiState>,
149 axum::Form(form_data): axum::Form<SamlAcsRequest>,
150) -> Json<ApiResponse<serde_json::Value>> {
151 let saml_response_xml =
153 match base64::engine::general_purpose::STANDARD.decode(&form_data.saml_response) {
154 Ok(decoded) => match String::from_utf8(decoded) {
155 Ok(xml) => xml,
156 Err(e) => {
157 return Json(ApiResponse::validation_error_typed(format!(
158 "Invalid SAML response UTF-8: {}",
159 e
160 )));
161 }
162 },
163 Err(e) => {
164 return Json(ApiResponse::validation_error_typed(format!(
165 "Invalid SAML response encoding: {}",
166 e
167 )));
168 }
169 };
170
171 if !saml_response_xml.contains("<saml:Assertion") {
173 return Json(ApiResponse::validation_error_typed(
174 "No SAML assertion found",
175 ));
176 }
177
178 let username = match extract_username_from_saml(&saml_response_xml) {
180 Ok(user) => user,
181 Err(e) => return Json(ApiResponse::error_typed("SAML_PARSE_ERROR", e)),
182 };
183
184 let attributes = match extract_attributes_from_saml(&saml_response_xml) {
185 Ok(attrs) => attrs,
186 Err(e) => return Json(ApiResponse::error_typed("SAML_PARSE_ERROR", e)),
187 };
188
189 let token_data = serde_json::json!({
191 "access_token": format!("saml_token_{}", uuid::Uuid::new_v4()),
192 "token_type": "Bearer",
193 "expires_in": 3600,
194 "user_id": username,
195 "authentication_method": "saml",
196 "attributes": attributes,
197 "relay_state": form_data.relay_state
198 });
199
200 Json(ApiResponse::success_with_message(
201 token_data,
202 "SAML authentication successful",
203 ))
204}
205
206pub async fn initiate_saml_slo(
208 State(_state): State<ApiState>,
209 Json(request): Json<SamlLogoutRequest>,
210) -> Json<ApiResponse<SamlLogoutResponse>> {
211 let logout_id = format!("logout_{}", uuid::Uuid::new_v4());
212 let issue_instant = chrono::Utc::now().to_rfc3339();
213
214 let saml_logout_request = format!(
216 r#"<samlp:LogoutRequest xmlns:samlp="urn:oasis:names:tc:SAML:2.0:protocol"
217 xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
218 ID="{}"
219 Version="2.0"
220 IssueInstant="{}"
221 Destination="https://idp.example.com/slo">
222 <saml:Issuer>https://auth.example.com</saml:Issuer>
223 <saml:NameID Format="urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress">{}</saml:NameID>
224 {}
225</samlp:LogoutRequest>"#,
226 logout_id,
227 issue_instant,
228 request.name_id,
229 request.session_index.map_or(String::new(), |si| format!(
230 r#"<samlp:SessionIndex>{}</samlp:SessionIndex>"#,
231 si
232 ))
233 );
234
235 let encoded_request = base64::engine::general_purpose::STANDARD.encode(&saml_logout_request);
236 let redirect_url = format!(
237 "https://idp.example.com/slo?SAMLRequest={}",
238 urlencoding::encode(&encoded_request)
239 );
240
241 Json(ApiResponse::success_with_message(
242 SamlLogoutResponse {
243 redirect_url,
244 status: "logout_initiated".to_string(),
245 },
246 "SAML logout initiated",
247 ))
248}
249
250pub async fn handle_saml_slo_response(
252 State(_state): State<ApiState>,
253 Query(params): Query<HashMap<String, String>>,
254) -> Json<ApiResponse<()>> {
255 let saml_response = match params.get("SAMLResponse") {
256 Some(response) => response,
257 None => {
258 return Json(ApiResponse::validation_error(
259 "Missing SAMLResponse parameter",
260 ));
261 }
262 };
263
264 let response_xml = match base64::engine::general_purpose::STANDARD.decode(saml_response) {
266 Ok(decoded) => match String::from_utf8(decoded) {
267 Ok(xml) => xml,
268 Err(e) => {
269 return Json(ApiResponse::validation_error(format!(
270 "Invalid SLO response UTF-8: {}",
271 e
272 )));
273 }
274 },
275 Err(e) => {
276 return Json(ApiResponse::validation_error(format!(
277 "Invalid SLO response encoding: {}",
278 e
279 )));
280 }
281 };
282
283 if response_xml.contains("urn:oasis:names:tc:SAML:2.0:status:Success") {
285 Json(ApiResponse::<()>::ok_with_message(
286 "SAML logout completed successfully",
287 ))
288 } else {
289 Json(ApiResponse::error(
290 "SAML_LOGOUT_FAILED",
291 "SAML logout failed",
292 ))
293 }
294}
295
296pub async fn create_saml_assertion(
298 State(_state): State<ApiState>,
299 Json(request): Json<serde_json::Value>,
300) -> Json<ApiResponse<String>> {
301 let username = match request["username"].as_str() {
302 Some(user) => user,
303 None => return Json(ApiResponse::validation_error_typed("Username required")),
304 };
305
306 let audience = match request["audience"].as_str() {
307 Some(aud) => aud,
308 None => return Json(ApiResponse::validation_error_typed("Audience required")),
309 };
310
311 let assertion_xml = format!(
313 r#"<saml:Assertion xmlns:saml="urn:oasis:names:tc:SAML:2.0:assertion"
314 ID="assertion_{}"
315 IssueInstant="{}"
316 Version="2.0">
317 <saml:Issuer>https://auth.example.com</saml:Issuer>
318 <saml:Subject>
319 <saml:NameID Format="urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress">{}@example.com</saml:NameID>
320 </saml:Subject>
321 <saml:Conditions NotBefore="{}" NotOnOrAfter="{}">
322 <saml:AudienceRestriction>
323 <saml:Audience>{}</saml:Audience>
324 </saml:AudienceRestriction>
325 </saml:Conditions>
326 <saml:AttributeStatement>
327 <saml:Attribute Name="username">
328 <saml:AttributeValue>{}</saml:AttributeValue>
329 </saml:Attribute>
330 <saml:Attribute Name="email">
331 <saml:AttributeValue>{}@example.com</saml:AttributeValue>
332 </saml:Attribute>
333 </saml:AttributeStatement>
334</saml:Assertion>"#,
335 uuid::Uuid::new_v4(),
336 chrono::Utc::now().to_rfc3339(),
337 username,
338 (chrono::Utc::now() - chrono::Duration::minutes(1)).to_rfc3339(),
339 (chrono::Utc::now() + chrono::Duration::hours(1)).to_rfc3339(),
340 audience,
341 username,
342 username
343 );
344
345 Json(ApiResponse::success_with_message(
346 assertion_xml,
347 "SAML assertion created",
348 ))
349}
350
351pub async fn list_saml_idps(
353 State(_state): State<ApiState>,
354) -> Json<ApiResponse<Vec<serde_json::Value>>> {
355 let idps = vec![serde_json::json!({
357 "entity_id": "https://idp.example.com",
358 "certificate": "example_cert",
359 "sso_url": "https://idp.example.com/sso",
360 "slo_url": "https://idp.example.com/slo"
361 })];
362
363 Json(ApiResponse::success_with_message(
364 idps,
365 "SAML IdPs retrieved",
366 ))
367}
368
/// Extracts the text of the first `<saml:NameID>` element as the authenticated username.
///
/// Matching is string-based (no XML parser) and assumes the `saml:` prefix. Two kinds of
/// start tag are skipped:
/// * tags that only share the prefix, such as `<saml:NameIDPolicy>`;
/// * self-closing `<saml:NameID/>` tags.
///
/// Surrounding whitespace is trimmed. The five predefined XML entities (`&lt;` `&gt;`
/// `&quot;` `&apos;` `&amp;`) are decoded; numeric character references are left as-is.
///
/// # Errors
///
/// Returns an error if no `NameID` with text content is found, or if that content is blank.
fn extract_username_from_saml(saml_xml: &str) -> Result<String, String> {
    const OPEN_TAG: &str = "<saml:NameID";
    const CLOSE_TAG: &str = "</saml:NameID>";

    let mut search_from = 0;
    while let Some(offset) = saml_xml[search_from..].find(OPEN_TAG) {
        let name_end = search_from + offset + OPEN_TAG.len();
        search_from = name_end;

        // The element name must end right here; otherwise this is a longer name such as
        // `NameIDPolicy` that merely shares the prefix.
        let name_terminated = saml_xml[name_end..]
            .chars()
            .next()
            .is_some_and(|c| c == '>' || c == '/' || c.is_ascii_whitespace());
        if !name_terminated {
            continue;
        }

        let Some(tag_close) = saml_xml[name_end..].find('>') else {
            break;
        };
        let content_start = name_end + tag_close + 1;
        // A self-closing `<saml:NameID ... />` has no text content.
        if saml_xml[..content_start].ends_with("/>") {
            continue;
        }

        let Some(content_len) = saml_xml[content_start..].find(CLOSE_TAG) else {
            break;
        };
        let username = saml_xml[content_start..content_start + content_len].trim();
        if username.is_empty() {
            return Err("SAML NameID is empty".to_string());
        }

        // `&amp;` is decoded last so `&amp;lt;` yields the literal text `&lt;`.
        let decoded = username
            .replace("&lt;", "<")
            .replace("&gt;", ">")
            .replace("&quot;", "\"")
            .replace("&apos;", "'")
            .replace("&amp;", "&");
        return Ok(decoded);
    }

    Err("Could not extract username from SAML assertion".to_string())
}
382
/// Collects `<saml:Attribute>` values from the assertion's `<saml:AttributeStatement>`.
///
/// Returns a map from each attribute's `Name` to its `<saml:AttributeValue>` texts, in
/// document order. An attribute with no values maps to an empty list.
///
/// Marker entries: whenever an `AttributeStatement` start tag is present (with or without
/// XML attributes of its own), `source = ["saml"]` and `auth_method = ["saml_sso"]` are
/// inserted as well, as before. They override any IdP attribute with those names. Without
/// an `AttributeStatement` the map is empty.
///
/// Parsing limits:
/// * Parsing is string-based, assumes the `saml:` prefix, and reads only the first
///   `AttributeStatement`.
/// * Values are trimmed and the five predefined XML entities are decoded.
/// * Self-closing values and attributes without a `Name` are skipped.
///
/// # Errors
///
/// Never returns `Err` at present; the `Result` is kept for callers.
fn extract_attributes_from_saml(saml_xml: &str) -> Result<HashMap<String, Vec<String>>, String> {
    /// Finds the next start tag named exactly `name` at or after byte `from`.
    /// Returns `(tag_start, content_start, self_closing)`, where `content_start` is the byte
    /// just past the tag's closing `>`.
    fn find_start_tag(xml: &str, name: &str, from: usize) -> Option<(usize, usize, bool)> {
        let open = format!("<{name}");
        let mut cursor = from;
        while let Some(offset) = xml[cursor..].find(&open) {
            let tag_start = cursor + offset;
            let name_end = tag_start + open.len();
            cursor = name_end;
            let terminated = xml[name_end..]
                .chars()
                .next()
                .is_some_and(|c| c == '>' || c == '/' || c.is_ascii_whitespace());
            if !terminated {
                // A longer element name sharing this prefix, e.g. `AttributeValue`.
                continue;
            }
            let content_start = name_end + xml[name_end..].find('>')? + 1;
            return Some((tag_start, content_start, xml[..content_start].ends_with("/>")));
        }
        None
    }

    /// Reads XML attribute `attr` (single- or double-quoted) from a start tag's text.
    fn attribute_value<'a>(tag: &'a str, attr: &str) -> Option<&'a str> {
        let needle = format!("{attr}=");
        let mut cursor = 0;
        while let Some(offset) = tag[cursor..].find(&needle) {
            let at = cursor + offset;
            cursor = at + needle.len();
            // Require a whole name: `FriendlyName=` must not match `Name`.
            if !tag[..at].ends_with(|c: char| c.is_ascii_whitespace()) {
                continue;
            }
            let rest = &tag[cursor..];
            let quote = rest.chars().next().filter(|q| *q == '"' || *q == '\'')?;
            let value = &rest[1..];
            return value.find(quote).map(|end| &value[..end]);
        }
        None
    }

    /// Decodes the predefined XML entities; `&amp;` goes last so `&amp;lt;` stays `&lt;`.
    fn decode_entities(text: &str) -> String {
        text.replace("&lt;", "<")
            .replace("&gt;", ">")
            .replace("&quot;", "\"")
            .replace("&apos;", "'")
            .replace("&amp;", "&")
    }

    let mut attributes: HashMap<String, Vec<String>> = HashMap::new();

    let Some((_, statement_start, statement_self_closing)) =
        find_start_tag(saml_xml, "saml:AttributeStatement", 0)
    else {
        return Ok(attributes);
    };

    if !statement_self_closing {
        let statement_end = saml_xml[statement_start..]
            .find("</saml:AttributeStatement>")
            .map_or(saml_xml.len(), |offset| statement_start + offset);
        let statement = &saml_xml[..statement_end];

        let mut cursor = statement_start;
        while let Some((tag_start, content_start, self_closing)) =
            find_start_tag(statement, "saml:Attribute", cursor)
        {
            let name = attribute_value(&statement[tag_start..content_start], "Name");
            let content_end = if self_closing {
                content_start
            } else {
                statement[content_start..]
                    .find("</saml:Attribute>")
                    .map_or(statement.len(), |offset| content_start + offset)
            };
            cursor = content_end;

            let Some(name) = name else {
                continue;
            };
            let values = attributes.entry(decode_entities(name)).or_default();
            let body = &statement[..content_end];
            let mut value_cursor = content_start;
            while let Some((_, value_start, value_self_closing)) =
                find_start_tag(body, "saml:AttributeValue", value_cursor)
            {
                if value_self_closing {
                    value_cursor = value_start;
                    continue;
                }
                let value_end = body[value_start..]
                    .find("</saml:AttributeValue>")
                    .map_or(body.len(), |offset| value_start + offset);
                values.push(decode_entities(body[value_start..value_end].trim()));
                value_cursor = value_end;
            }
        }
    }

    // Marker entries returned by earlier versions; kept for existing consumers.
    attributes.insert("source".to_string(), vec!["saml".to_string()]);
    attributes.insert("auth_method".to_string(), vec!["saml_sso".to_string()]);

    Ok(attributes)
}