mod device_code_responses;
use azure_core::{
content_type,
error::{Error, ErrorKind},
from_json, headers, sleep, HttpClient, Method, Request, Response, Url,
};
pub use device_code_responses::*;
use futures::stream::unfold;
use serde::Deserialize;
use std::{borrow::Cow, pin::Pin, sync::Arc, time::Duration};
use url::form_urlencoded;
/// Begins the device code flow by requesting a device code for `client_id`
/// and `scopes` from the tenant's `devicecode` endpoint.
///
/// The returned [`DeviceCodePhaseOneResponse`] also stores the HTTP client,
/// tenant id and client id, so that [`DeviceCodePhaseOneResponse::stream`]
/// can later poll the token endpoint with them.
///
/// # Errors
///
/// Returns an error if the request cannot be sent, the response body cannot
/// be read, the endpoint answers with a non-success status, or the body is
/// not valid device code JSON.
pub async fn start<'a, 'b, T>(
    http_client: Arc<dyn HttpClient>,
    tenant_id: T,
    client_id: &str,
    scopes: &'b [&'b str],
) -> azure_core::Result<DeviceCodePhaseOneResponse<'a>>
where
    T: Into<Cow<'a, str>>,
{
    let tenant_id: Cow<'a, str> = tenant_id.into();
    let form = form_urlencoded::Serializer::new(String::new())
        .append_pair("client_id", client_id)
        .append_pair("scope", &scopes.join(" "))
        .finish();
    let endpoint =
        format!("https://login.microsoftonline.com/{tenant_id}/oauth2/v2.0/devicecode");

    let response = post_form(Arc::clone(&http_client), &endpoint, form).await?;
    let (status, response_headers, body) = response.deconstruct();
    let body = body.collect().await?;

    if status.is_success() {
        let phase_one: DeviceCodePhaseOneResponse<'a> = from_json(&body)?;
        // Only the wire fields come from the JSON; the skipped fields are
        // filled in here so the response can poll on its own later.
        Ok(DeviceCodePhaseOneResponse {
            http_client: Some(http_client),
            tenant_id,
            client_id: client_id.to_string(),
            ..phase_one
        })
    } else {
        Err(ErrorKind::http_response_from_parts(status, &response_headers, &body).into_error())
    }
}
/// Result of the first phase of the device code flow, as returned by [`start`].
///
/// The first six fields are deserialized from the device code endpoint's JSON
/// response. The remaining fields are `#[serde(skip)]` and are filled in by
/// `start`, so the value can later poll the token endpoint on its own.
#[derive(Debug, Clone, Deserialize)]
pub struct DeviceCodePhaseOneResponse<'a> {
/// Device code sent as the `device_code` form field when polling the token endpoint.
device_code: String,
/// Code returned by the endpoint; not read in this file (presumably shown to the user — confirm).
user_code: String,
/// URI returned by the endpoint; not read in this file (presumably where the user enters `user_code`).
verification_uri: String,
/// Lifetime of the device code as returned by the endpoint; presumably seconds — TODO confirm.
expires_in: u64,
/// Polling interval, interpreted as seconds (converted with `Duration::from_secs`).
interval: u64,
/// Text exposed through `message()`.
message: String,
/// Client used for polling. Set by `start`; `None` when this value is deserialized directly,
/// because skipped fields take their `Default`.
#[serde(skip)]
http_client: Option<Arc<dyn HttpClient>>,
/// Tenant whose token endpoint is polled. Set by `start`.
#[serde(skip)]
tenant_id: Cow<'a, str>,
/// Client id sent when polling. Set by `start`.
#[serde(skip)]
client_id: String,
}
impl<'a> DeviceCodePhaseOneResponse<'a> {
pub fn message(&self) -> &str {
&self.message
}
pub fn stream(
&self,
) -> Pin<Box<impl futures::Stream<Item = azure_core::Result<DeviceCodeAuthorization>> + '_>>
{
#[derive(Debug, Clone, PartialEq, Eq)]
enum NextState {
Continue,
Finish,
}
Box::pin(unfold(
NextState::Continue,
move |state: NextState| async move {
match state {
NextState::Continue => {
let url = &format!(
"https://login.microsoftonline.com/{}/oauth2/v2.0/token",
self.tenant_id,
);
sleep(Duration::from_secs(self.interval)).await;
let encoded = form_urlencoded::Serializer::new(String::new())
.append_pair(
"grant_type",
"urn:ietf:params:oauth:grant-type:device_code",
)
.append_pair("client_id", self.client_id.as_str())
.append_pair("device_code", &self.device_code)
.finish();
let http_client = self.http_client.clone().unwrap();
match post_form(http_client.clone(), url, encoded).await {
Ok(rsp) => {
let rsp_status = rsp.status();
let rsp_body = match rsp.into_body().collect().await {
Ok(b) => b,
Err(e) => return Some((Err(e), NextState::Finish)),
};
if rsp_status.is_success() {
match from_json::<_, DeviceCodeAuthorization>(&rsp_body) {
Ok(authorization) => {
Some((Ok(authorization), NextState::Finish))
}
Err(error) => Some((Err(error), NextState::Finish)),
}
} else {
match from_json::<_, DeviceCodeErrorResponse>(&rsp_body) {
Ok(error_rsp) => {
let next_state =
if error_rsp.error == "authorization_pending" {
NextState::Continue
} else {
NextState::Finish
};
Some((
Err(Error::new(ErrorKind::Credential, error_rsp)),
next_state,
))
}
Err(error) => Some((Err(error), NextState::Finish)),
}
}
}
Err(error) => Some((Err(error), NextState::Finish)),
}
}
NextState::Finish => None,
}
},
))
}
}
/// Sends `form_body` to `url` as an HTTP POST with an
/// `application/x-www-form-urlencoded` content type.
///
/// The response is returned regardless of its status code; callers check it.
///
/// # Errors
///
/// Fails if `url` cannot be parsed or if `http_client` fails to execute the
/// request.
async fn post_form(
    http_client: Arc<dyn HttpClient>,
    url: &str,
    form_body: String,
) -> azure_core::Result<Response> {
    let request = {
        let mut request = Request::new(Url::parse(url)?, Method::Post);
        request.insert_header(
            headers::CONTENT_TYPE,
            content_type::APPLICATION_X_WWW_FORM_URLENCODED,
        );
        request.set_body(form_body);
        request
    };
    http_client.execute_request(&request).await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only when the argument's type is `Send`.
    fn assert_send<F: Send>(_value: F) {}

    /// The future returned by `start` must be `Send`. It is constructed here
    /// but never polled, so no network request is made.
    #[test]
    fn ensure_that_start_is_send() {
        let future = start(azure_core::new_http_client(), "UNUSED", "UNUSED", &[]);
        assert_send(future);
    }
}