qcs_api_client_common/configuration/py.rs

1#![allow(unused_qualifications)]
2
3use pyo3::{
4    exceptions::{PyFileNotFoundError, PyOSError, PyRuntimeError, PyValueError},
5    prelude::*,
6    types::PyFunction,
7};
8use pyo3_asyncio::tokio::get_runtime;
9use rigetti_pyo3::{create_init_submodule, py_function_sync_async};
10use tokio_util::sync::CancellationToken;
11
12use crate::{
13    configuration::{
14        PkceFlow, API_URL_VAR, DEFAULT_API_URL, DEFAULT_GRPC_API_URL, DEFAULT_PROFILE_NAME,
15        DEFAULT_QUILC_URL, DEFAULT_QVM_URL, DEFAULT_SECRETS_PATH, DEFAULT_SETTINGS_PATH,
16        GRPC_API_URL_VAR, PROFILE_NAME_VAR, QUILC_URL_VAR, QVM_URL_VAR, SECRETS_PATH_VAR,
17        SETTINGS_PATH_VAR,
18    },
19    impl_eq, impl_repr,
20};
21
22use super::{
23    error::TokenError, settings::AuthServer, tokens::ClientCredentials, ClientConfiguration,
24    ClientConfigurationBuilder, ExternallyManaged, LoadError, OAuthGrant, OAuthSession,
25    RefreshToken, TokenDispatcher,
26};
27
28create_init_submodule! {
29    classes: [
30        ClientConfiguration,
31        PyClientConfigurationBuilder,
32        AuthServer,
33        OAuthSession,
34        RefreshToken,
35        ClientCredentials,
36        ExternallyManaged,
37        PkceFlow
38    ],
39    consts: [
40        DEFAULT_API_URL,
41        DEFAULT_GRPC_API_URL,
42        DEFAULT_QUILC_URL,
43        DEFAULT_QVM_URL,
44        DEFAULT_PROFILE_NAME,
45        PROFILE_NAME_VAR,
46        QUILC_URL_VAR,
47        QVM_URL_VAR,
48        API_URL_VAR,
49        GRPC_API_URL_VAR,
50        SETTINGS_PATH_VAR,
51        DEFAULT_SETTINGS_PATH,
52        SECRETS_PATH_VAR,
53        DEFAULT_SECRETS_PATH
54    ],
55}
56
57impl_eq!(RefreshToken);
58impl_repr!(RefreshToken);
59#[pymethods]
60impl RefreshToken {
61    #[new]
62    const fn __new__(refresh_token: String) -> Self {
63        Self::new(refresh_token)
64    }
65
66    #[getter]
67    #[pyo3(name = "refresh_token")]
68    fn py_refresh_token(&self) -> &str {
69        &self.refresh_token
70    }
71
72    #[setter]
73    #[pyo3(name = "refresh_token")]
74    fn py_set_refresh_token(&mut self, refresh_token: String) {
75        self.refresh_token = refresh_token;
76    }
77}
78
79impl_eq!(ClientCredentials);
80// Does not implement `__repr__`, since the data contains a secret value.
81#[pymethods]
82impl ClientCredentials {
83    #[new]
84    const fn __new__(client_id: String, client_secret: String) -> Self {
85        Self::new(client_id, client_secret)
86    }
87
88    #[getter]
89    #[pyo3(name = "client_id")]
90    fn py_client_id(&self) -> &str {
91        self.client_id()
92    }
93
94    #[getter]
95    #[pyo3(name = "client_secret")]
96    fn py_client_secret(&self) -> &str {
97        self.client_secret()
98    }
99}
100
101impl_repr!(ExternallyManaged);
102#[pymethods]
103impl ExternallyManaged {
104    #[new]
105    fn __new__(refresh_function: Py<PyFunction>) -> Self {
106        #[allow(trivial_casts)] // Compilation fails without the cast.
107        // The provided refresh function will panic if there is an issue with the refresh function.
108        // This raises a `PanicException` within Python.
109        let refresh_closure = move |auth_server: AuthServer| {
110            let refresh_function = refresh_function.clone();
111            Box::pin(async move {
112                Python::with_gil(|py| {
113                    let result = refresh_function.call1(py, (auth_server.into_py(py),));
114                    match result {
115                        Ok(value) => value
116                            .extract::<String>(py)
117                            .map_or_else(|_| panic!("ExternallyManaged refresh function returned an unexpected type. Expected a string, got {value:?}"), Ok),
118                        Err(err) => Err(Box::<dyn std::error::Error + Send + Sync>::from(err))
119                    }
120                })
121            }) as super::tokens::RefreshResult
122        };
123
124        Self::new(refresh_closure)
125    }
126}
127
128impl_eq!(PkceFlow);
129// Does not implement `__repr__`, since the data contains a secret value.
130#[pymethods]
131impl PkceFlow {
132    #[new]
133    fn __new__(py: Python<'_>, auth_server: AuthServer) -> PyResult<Self> {
134        py.allow_threads(move || {
135            let runtime = get_runtime();
136            runtime.block_on(async move {
137                let cancel_token = cancel_token_with_ctrl_c();
138                PkceFlow::new_login_flow(cancel_token, &auth_server).await
139            })
140        })
141        .map_err(|err| PyRuntimeError::new_err(err.to_string()))
142    }
143
144    #[getter]
145    #[pyo3(name = "access_token")]
146    fn py_access_token(&self) -> &str {
147        &self.access_token
148    }
149
150    #[getter]
151    #[pyo3(name = "refresh_token")]
152    fn py_refresh_token(&self) -> Option<&str> {
153        self.refresh_token
154            .as_ref()
155            .map(|rt| rt.refresh_token.as_str())
156    }
157}
158
159impl_repr!(OAuthSession);
160#[pymethods]
161impl OAuthSession {
162    #[new]
163    const fn __new__(
164        payload: OAuthGrant,
165        auth_server: AuthServer,
166        access_token: Option<String>,
167    ) -> Self {
168        Self::new(payload, auth_server, access_token)
169    }
170
171    #[getter]
172    #[pyo3(name = "access_token")]
173    fn py_access_token(&self) -> Result<&str, TokenError> {
174        self.access_token()
175    }
176
177    #[getter]
178    #[pyo3(name = "payload")]
179    fn py_payload(&self, py: Python<'_>) -> PyObject {
180        match self.payload() {
181            OAuthGrant::ClientCredentials(ref client_credentials) => {
182                client_credentials.clone().into_py(py)
183            }
184            OAuthGrant::RefreshToken(ref refresh_token) => refresh_token.clone().into_py(py),
185            OAuthGrant::ExternallyManaged(ref externally_managed) => {
186                externally_managed.clone().into_py(py)
187            }
188            OAuthGrant::PkceFlow(ref pkce_tokens) => pkce_tokens.clone().into_py(py),
189        }
190    }
191
192    #[getter]
193    #[pyo3(name = "auth_server")]
194    fn py_auth_server(&self) -> AuthServer {
195        self.auth_server().clone()
196    }
197
198    #[pyo3(name = "validate")]
199    fn py_validate(&self) -> Result<String, TokenError> {
200        self.validate()
201    }
202
203    #[pyo3(name = "request_access_token")]
204    fn py_request_access_token(&self, py: Python<'_>) -> PyResult<String> {
205        py_request_access_token(py, self.clone())
206    }
207
208    #[pyo3(name = "request_access_token_async")]
209    fn py_request_access_token_async<'a>(&'a self, py: Python<'a>) -> PyResult<&'a PyAny> {
210        py_request_access_token_async(py, self.clone())
211    }
212}
213
214py_function_sync_async! {
215    #[pyfunction]
216    async fn get_oauth_session(tokens: Option<TokenDispatcher>) -> PyResult<OAuthSession> {
217        Ok(tokens.ok_or(TokenError::NoRefreshToken)?.tokens().await)
218    }
219}
220
221py_function_sync_async! {
222    #[pyfunction]
223    async fn get_bearer_access_token(configuration: ClientConfiguration) -> PyResult<String> {
224        configuration.get_bearer_access_token().await.map_err(PyErr::from)
225    }
226}
227
228py_function_sync_async! {
229    #[pyfunction]
230    async fn request_access_token(session: OAuthSession) -> PyResult<String> {
231        session.clone().request_access_token().await.map(std::string::ToString::to_string).map_err(PyErr::from)
232    }
233}
234
235impl_repr!(ClientConfiguration);
236#[pymethods]
237impl ClientConfiguration {
238    #[staticmethod]
239    #[pyo3(name = "load_default")]
240    fn py_load_default(_py: Python<'_>) -> Result<Self, LoadError> {
241        Self::load_default()
242    }
243
244    #[staticmethod]
245    #[pyo3(name = "load_default_with_login")]
246    fn py_load_default_with_login(py: Python<'_>) -> PyResult<Self> {
247        py.allow_threads(move || {
248            let runtime = get_runtime();
249            runtime.block_on(async move {
250                let cancel_token = cancel_token_with_ctrl_c();
251                Self::load_with_login(cancel_token, None).await
252            })
253        })
254        .map_err(|err| PyRuntimeError::new_err(err.to_string()))
255    }
256
257    #[staticmethod]
258    #[pyo3(name = "builder")]
259    fn py_builder() -> PyClientConfigurationBuilder {
260        PyClientConfigurationBuilder::default()
261    }
262
263    #[staticmethod]
264    #[pyo3(name = "load_profile")]
265    fn py_load_profile(_py: Python<'_>, profile_name: String) -> Result<Self, LoadError> {
266        Self::load_profile(profile_name)
267    }
268
269    #[getter]
270    fn get_api_url(&self) -> &str {
271        &self.api_url
272    }
273
274    #[getter]
275    fn get_grpc_api_url(&self) -> &str {
276        &self.grpc_api_url
277    }
278
279    #[getter]
280    fn get_quilc_url(&self) -> &str {
281        &self.quilc_url
282    }
283
284    #[getter]
285    fn get_qvm_url(&self) -> &str {
286        &self.qvm_url
287    }
288
289    #[pyo3(name = "get_bearer_access_token")]
290    fn py_get_bearer_access_token(&self, py: Python<'_>) -> PyResult<String> {
291        py_get_bearer_access_token(py, self.clone())
292    }
293
294    #[pyo3(name = "get_bearer_access_token_async")]
295    fn py_get_bearer_access_token_async<'a>(&self, py: Python<'a>) -> PyResult<&'a PyAny> {
296        py_get_bearer_access_token_async(py, self.clone())
297    }
298
299    /// Get the configured tokens.
300    ///
301    /// # Errors
302    ///
303    /// - Raises a `TokenError` if there is a problem fetching the tokens
304    pub fn get_oauth_session(&self, py: Python<'_>) -> PyResult<OAuthSession> {
305        py_get_oauth_session(py, self.oauth_session.clone())
306    }
307
308    #[allow(clippy::needless_pass_by_value)] // self_ must be passed by value
309    fn get_oauth_session_async<'a>(
310        self_: PyRefMut<'a, Self>,
311        py: Python<'a>,
312    ) -> PyResult<&'a PyAny> {
313        py_get_oauth_session_async(py, self_.oauth_session.clone())
314    }
315}
316
317#[pyclass]
318#[pyo3(name = "ClientConfigurationBuilder")]
319#[derive(Clone, Default)]
320struct PyClientConfigurationBuilder(ClientConfigurationBuilder);
321
322#[pymethods]
323impl PyClientConfigurationBuilder {
324    #[new]
325    fn new() -> Self {
326        Self::default()
327    }
328
329    fn build(&self) -> Result<ClientConfiguration, LoadError> {
330        Ok(self.0.build()?)
331    }
332
333    #[setter]
334    fn api_url(&mut self, api_url: String) {
335        self.0.api_url(api_url);
336    }
337
338    #[setter]
339    fn grpc_api_url(&mut self, grpc_api_url: String) {
340        self.0.grpc_api_url(grpc_api_url);
341    }
342
343    #[setter]
344    fn quilc_url(&mut self, quilc_url: String) {
345        self.0.quilc_url(quilc_url);
346    }
347
348    #[setter]
349    fn qvm_url(&mut self, qvm_url: String) {
350        self.0.qvm_url(qvm_url);
351    }
352
353    #[setter]
354    fn oauth_session(&mut self, oauth_session: Option<OAuthSession>) {
355        self.0.oauth_session(oauth_session);
356    }
357}
358
359impl_repr!(AuthServer);
360impl_eq!(AuthServer);
361#[pymethods]
362impl AuthServer {
363    #[new]
364    const fn py_new(client_id: String, issuer: String) -> Self {
365        Self::new(client_id, issuer)
366    }
367
368    #[staticmethod]
369    #[pyo3(name = "default")]
370    fn py_default() -> Self {
371        Self::default()
372    }
373
374    /// Get the configured Okta client id.
375    #[getter]
376    #[must_use]
377    pub fn get_client_id(&self) -> &str {
378        self.client_id()
379    }
380
381    /// Set an Okta client id.
382    #[setter(client_id)]
383    pub fn py_set_client_id(&mut self, id: String) {
384        self.set_client_id(id);
385    }
386
387    /// Get the Okta issuer URL.
388    #[getter]
389    #[must_use]
390    pub fn get_issuer(&self) -> &str {
391        self.issuer()
392    }
393
394    /// Set an Okta issuer URL.
395    #[setter(issuer)]
396    pub fn py_set_issuer(&mut self, issuer: String) {
397        self.set_issuer(issuer);
398    }
399}
400
401impl From<LoadError> for PyErr {
402    fn from(value: LoadError) -> Self {
403        let message = value.to_string();
404        match value {
405            LoadError::Load(_)
406            | LoadError::Build(_)
407            | LoadError::ProfileNotFound(_)
408            | LoadError::AuthServerNotFound(_)
409            | LoadError::PkceFlow(_) => PyValueError::new_err(message),
410            LoadError::EnvVar { .. } | LoadError::Io(_) => PyOSError::new_err(message),
411            LoadError::Path { .. } => PyFileNotFoundError::new_err(message),
412            #[cfg(feature = "tracing-config")]
413            LoadError::TracingFilterParseError(_) => PyValueError::new_err(message),
414        }
415    }
416}
417
418impl From<TokenError> for PyErr {
419    fn from(value: TokenError) -> Self {
420        let message = value.to_string();
421        match value {
422            TokenError::NoRefreshToken
423            | TokenError::NoCredentials
424            | TokenError::NoAccessToken
425            | TokenError::NoAuthServer
426            | TokenError::InvalidAccessToken(_)
427            | TokenError::Fetch(_)
428            | TokenError::ExternallyManaged(_)
429            | TokenError::Write(_)
430            | TokenError::Discovery(_) => PyValueError::new_err(message),
431        }
432    }
433}
434
435fn cancel_token_with_ctrl_c() -> CancellationToken {
436    let cancel_token = CancellationToken::new();
437    let cancel_token_ctrl_c = cancel_token.clone();
438    tokio::spawn(cancel_token.clone().run_until_cancelled_owned(async move {
439        match tokio::signal::ctrl_c().await {
440            Ok(()) => cancel_token_ctrl_c.cancel(),
441            Err(error) => eprintln!("Failed to register signal handler: {error}"),
442        }
443    }));
444    cancel_token
445}