Skip to main content

qcs_api_client_common/configuration/
py.rs

1#![allow(unused_qualifications)]
2#![allow(non_local_definitions, reason = "necessary for pyo3::pymethods")]
3
4use pyo3::{
5    exceptions::PyValueError,
6    prelude::*,
7    types::{PyAnyMethods, PyString},
8};
9use rigetti_pyo3::{create_init_submodule, impl_repr, py_function_sync_async, sync::Awaitable};
10use tokio_util::sync::CancellationToken;
11
12#[cfg(feature = "stubs")]
13use pyo3_stub_gen::derive::{gen_stub_pyfunction, gen_stub_pymethods};
14
15use crate::configuration::{
16    secrets::{DEFAULT_SECRETS_PATH, SECRETS_PATH_VAR},
17    settings::{DEFAULT_SETTINGS_PATH, SETTINGS_PATH_VAR},
18    ClientConfigurationBuilderError, API_URL_VAR, DEFAULT_API_URL, DEFAULT_GRPC_API_URL,
19    DEFAULT_PROFILE_NAME, DEFAULT_QUILC_URL, DEFAULT_QVM_URL, GRPC_API_URL_VAR, PROFILE_NAME_VAR,
20    QUILC_URL_VAR, QVM_URL_VAR,
21};
22use crate::errors;
23
24use super::{
25    error::TokenError,
26    secrets::{SecretAccessToken, SecretRefreshToken},
27    settings::AuthServer,
28    tokens::{ClientCredentials, ClientSecret, ExternallyManaged, PkceFlow},
29    ClientConfiguration, ClientConfigurationBuilder, LoadError, OAuthGrant, OAuthSession,
30    RefreshToken, TokenDispatcher,
31};
32
// Registers the Python surface of `qcs_api_client_common.configuration`: the pyclasses, the
// module-level string constants, the exception types from `crate::errors`, and the Python
// functions (presumably the `py_*` wrappers generated by the `py_function_sync_async!`
// invocations in this file — verify against that macro).
// NOTE(review): the `consts` list must stay in sync with the `stub_consts!` invocation in this
// file, which emits the matching `.pyi` entries.
create_init_submodule! {
    classes: [
        ClientConfiguration,
        ClientConfigurationBuilder,
        AuthServer,
        OAuthSession,
        RefreshToken,
        ClientCredentials,
        ClientSecret,
        ExternallyManaged,
        PkceFlow,
        SecretAccessToken,
        SecretRefreshToken,
        TokenDispatcher
    ],

    consts: [
        API_URL_VAR,
        DEFAULT_API_URL,
        DEFAULT_GRPC_API_URL,
        DEFAULT_PROFILE_NAME,
        DEFAULT_QUILC_URL,
        DEFAULT_QVM_URL,
        DEFAULT_SECRETS_PATH,
        DEFAULT_SETTINGS_PATH,
        GRPC_API_URL_VAR,
        PROFILE_NAME_VAR,
        QUILC_URL_VAR,
        QVM_URL_VAR,
        SECRETS_PATH_VAR,
        SETTINGS_PATH_VAR
    ],

    errors: [
        errors::ClientConfigurationBuilderError,
        errors::ConfigurationError,
        errors::LoadError,
        errors::TokenError
    ],

    funcs: [
        py_get_oauth_session,
        py_get_oauth_session_async,
        py_get_bearer_access_token,
        py_get_bearer_access_token_async,
        py_request_access_token,
        py_request_access_token_async
    ],

}
83
/// Stub-generation-only wrapper used to annotate module-level constants as `typing.Final`.
///
/// Only compiled with the `stubs` feature. The wrapped type `T` does not appear in the emitted
/// annotation; see [`pyo3_stub_gen::PyStubType::type_output`] below.
#[cfg(feature = "stubs")]
#[derive(IntoPyObject)]
struct Final<T>(T);

#[cfg(feature = "stubs")]
impl<T> pyo3_stub_gen::PyStubType for Final<T> {
    /// Always reports the bare `typing.Final` (importing it from `typing`), regardless of `T`.
    // Presumably intentional: a bare `Final` lets type checkers infer the type from the value
    // written into the stub — TODO confirm `module_variable!` emits that value.
    fn type_output() -> pyo3_stub_gen::TypeInfo {
        pyo3_stub_gen::TypeInfo::with_module("typing.Final", "typing".into())
    }
}
94
/// Adds a module-level `str` constant to the `qcs_api_client_common.configuration` stub file
/// for each identifier passed in.
///
/// Each identifier expands to a `pyo3_stub_gen::module_variable!` call, named after the
/// identifier, typed as `Final<&str>`, and given `Final($name)` as its value. Every expansion is
/// gated on the `stubs` feature, so the macro produces nothing in normal builds.
macro_rules! stub_consts {
    ( $($name:ident),* ) => {
        $(
            #[cfg(feature = "stubs")]
            ::pyo3_stub_gen::module_variable!(
                "qcs_api_client_common.configuration",
                stringify!($name),
                Final<&str>,
                Final($name)
            );
        )*
    };
}
109
// Stub entries for the constants registered in the `consts` list of `create_init_submodule!`.
// NOTE(review): keep this list identical to that one; nothing checks that they match.
stub_consts!(
    API_URL_VAR,
    DEFAULT_API_URL,
    DEFAULT_GRPC_API_URL,
    DEFAULT_PROFILE_NAME,
    DEFAULT_QUILC_URL,
    DEFAULT_QVM_URL,
    DEFAULT_SECRETS_PATH,
    DEFAULT_SETTINGS_PATH,
    GRPC_API_URL_VAR,
    PROFILE_NAME_VAR,
    QUILC_URL_VAR,
    QVM_URL_VAR,
    SECRETS_PATH_VAR,
    SETTINGS_PATH_VAR
);
126
127/// Manual implementation to extract tokens from Python objects.
128///
129/// For Python functions that require a `SecretRefreshToken`,
130/// users can provide a Python `str`, a `RefreshToken`, or a `SecretRefreshToken`.
131impl FromPyObject<'_, '_> for SecretRefreshToken {
132    type Error = PyErr;
133
134    fn extract(obj: Borrowed<'_, '_, PyAny>) -> Result<Self, Self::Error> {
135        if let Ok(token) = obj.cast::<PyString>() {
136            Ok(Self::__new__(token.extract()?))
137        } else if let Ok(token) = obj.cast::<RefreshToken>() {
138            Ok(token.borrow().refresh_token.clone())
139        } else if let Ok(token) = obj.cast::<Self>() {
140            Ok(token.borrow().clone())
141        } else {
142            Err(PyValueError::new_err(
143                "expected str | SecretRefreshToken | RefreshToken",
144            ))
145        }
146    }
147}
148
149impl FromPyObject<'_, '_> for SecretAccessToken {
150    type Error = PyErr;
151
152    fn extract(obj: Borrowed<'_, '_, PyAny>) -> Result<Self, Self::Error> {
153        if let Ok(token) = obj.cast::<PyString>() {
154            Ok(Self::__new__(token.extract()?))
155        } else if let Ok(token) = obj.cast::<Self>() {
156            Ok(token.borrow().clone())
157        } else {
158            Err(PyValueError::new_err("expected str | SecretAccessToken"))
159        }
160    }
161}
162
163impl FromPyObject<'_, '_> for ClientSecret {
164    type Error = PyErr;
165
166    fn extract(obj: Borrowed<'_, '_, PyAny>) -> Result<Self, Self::Error> {
167        if let Ok(token) = obj.cast::<PyString>() {
168            Ok(Self::__new__(token.extract()?))
169        } else if let Ok(token) = obj.cast::<Self>() {
170            Ok(token.borrow().clone())
171        } else {
172            Err(PyValueError::new_err("expected str | ClientSecret"))
173        }
174    }
175}
176
impl_repr!(RefreshToken);

#[cfg_attr(feature = "stubs", gen_stub_pymethods)]
#[pymethods]
impl RefreshToken {
    /// Create a `RefreshToken` grant from `refresh_token`.
    ///
    /// Through the `FromPyObject` impl for `SecretRefreshToken` in this file, `refresh_token`
    /// may be a `str`, a `RefreshToken`, or a `SecretRefreshToken`.
    #[new]
    const fn __new__(refresh_token: SecretRefreshToken) -> Self {
        Self::new(refresh_token)
    }
}
187
impl_repr!(ClientCredentials);

#[cfg_attr(feature = "stubs", gen_stub_pymethods)]
#[pymethods]
impl ClientCredentials {
    /// Create client credentials from a `client_id` and a plain-string `client_secret`.
    ///
    /// The secret is wrapped with `ClientSecret::from` before being stored.
    // NOTE(review): this takes `String`, so the `str | ClientSecret` `FromPyObject` impl for
    // `ClientSecret` in this file is not used here, and passing a `ClientSecret` object raises.
    // Changing the parameter to `ClientSecret` would also turn the bad-input error from
    // `TypeError` into `ValueError` and change the generated stub, so it was left as-is.
    #[new]
    fn __new__(client_id: String, client_secret: String) -> Self {
        Self::new(client_id, ClientSecret::from(client_secret))
    }
}
198
199impl_repr!(ExternallyManaged);
200
201#[cfg_attr(not(feature = "stubs"), optipy::strip_pyo3(only_stubs))]
202#[cfg_attr(feature = "stubs", gen_stub_pymethods)]
203#[pymethods]
204impl ExternallyManaged {
205    #[new]
206    fn __new__(
207        #[gen_stub(
208            override_type(
209                type_repr="collections.abc.Callable[[AuthServer], str]",
210                imports=("collections.abc")
211            )
212        )]
213        refresh_function: Bound<'_, PyAny>,
214    ) -> PyResult<Self> {
215        if !refresh_function.is_callable() {
216            return Err(pyo3::exceptions::PyTypeError::new_err(
217                "refresh_function must be callable",
218            ));
219        }
220
221        let refresh_function = refresh_function.unbind();
222
223        #[allow(trivial_casts)] // Compilation fails without the cast.
224        // The provided refresh function will panic if there is an issue with the refresh function.
225        // This raises a `PanicException` within Python.
226        let refresh_closure = move |auth_server: AuthServer| {
227            let refresh_function = Python::attach(|py| refresh_function.clone_ref(py));
228            Box::pin(async move {
229                Python::attach(|py| {
230                    let result = refresh_function.call1(py, (auth_server,));
231                    match result {
232                        Ok(value) => value
233                            .extract::<String>(py)
234                            .map_or_else(|_| panic!("ExternallyManaged refresh function returned an unexpected type. Expected a string, got {value:?}"), Ok),
235                        Err(err) => Err(Box::<dyn std::error::Error + Send + Sync>::from(err))
236                    }
237                })
238            }) as super::tokens::RefreshResult
239        };
240
241        Ok(Self::new(refresh_closure))
242    }
243}
244
245impl_repr!(PkceFlow);
246
247#[cfg_attr(feature = "stubs", gen_stub_pymethods)]
248#[pymethods]
249impl PkceFlow {
250    #[new]
251    fn __new__(py: Python<'_>, auth_server: AuthServer) -> PyResult<Self> {
252        pyo3_async_runtimes::tokio::run(py, async move {
253            let cancel_token = cancel_token_with_ctrl_c();
254            Self::new_login_flow(cancel_token, &auth_server)
255                .await
256                .map_err(|err| LoadError::from(err).into())
257        })
258    }
259}
260
// Describes `OAuthGrant` in the generated stubs as the union of the grant classes this module
// exposes. `ClientConfiguration` is a configuration, not a grant, so it does not belong here;
// `ClientCredentials` is the client-credentials grant class imported from `tokens`.
// NOTE(review): keep this union in sync with the variants of `OAuthGrant`.
#[cfg(feature = "stubs")]
pyo3_stub_gen::impl_stub_type!(
    OAuthGrant = RefreshToken | ClientCredentials | ExternallyManaged | PkceFlow
);
265
impl_repr!(OAuthSession);

#[cfg_attr(feature = "stubs", gen_stub_pymethods)]
#[pymethods]
impl OAuthSession {
    /// Create a session from an OAuth grant `payload` and the `auth_server` to authenticate
    /// against, optionally starting with an existing `access_token`.
    #[new]
    #[pyo3(signature = (payload, auth_server, access_token = None))]
    const fn __new__(
        payload: OAuthGrant,
        auth_server: AuthServer,
        access_token: Option<SecretAccessToken>,
    ) -> Self {
        Self::new(payload, auth_server, access_token)
    }

    /// Return the access token produced by `OAuthSession::validate`.
    ///
    /// # Errors
    ///
    /// Raises a `TokenError` if validation fails.
    #[pyo3(name = "validate")]
    fn py_validate(&self) -> Result<SecretAccessToken, TokenError> {
        self.validate()
    }

    /// Request a new access token for this session, blocking until it is returned.
    ///
    /// # Errors
    ///
    /// Raises a `TokenError` if there's a problem requesting the token.
    // NOTE(review): the request runs on `self.clone()`. If `OAuthSession::request_access_token`
    // caches the new token on its receiver, that update is not reflected on this Python object —
    // confirm that is intended.
    #[pyo3(name = "request_access_token")]
    fn py_request_access_token(&self, py: Python<'_>) -> PyResult<SecretAccessToken> {
        py_request_access_token(py, self.clone())
    }

    /// Async variant of `request_access_token`; returns an awaitable resolving to the token.
    #[pyo3(name = "request_access_token_async")]
    fn py_request_access_token_async<'py>(
        &self,
        py: Python<'py>,
    ) -> PyResult<Awaitable<'py, SecretAccessToken>> {
        py_request_access_token_async(py, self.clone())
    }
}
299
py_function_sync_async! {
    /// Return the [`OAuthSession`] currently held by `tokens`, via `TokenDispatcher::tokens`.
    ///
    /// # Errors
    ///
    /// Raises a `TokenError` (`NoRefreshToken`) if `tokens` is `None`.
    #[cfg_attr(feature = "stubs", gen_stub_pyfunction(module = "qcs_api_client_common.configuration"))]
    #[pyfunction]
    async fn get_oauth_session(tokens: Option<TokenDispatcher>) -> PyResult<OAuthSession> {
        Ok(tokens.ok_or(TokenError::NoRefreshToken)?.tokens().await)
    }
}
307
py_function_sync_async! {
    /// Gets the `Bearer` access token, refreshing it if it is expired.
    ///
    /// `ClientConfiguration.get_bearer_access_token` delegates to this function.
    ///
    /// # Errors
    ///
    /// Raises a `TokenError` if there's a problem providing the token.
    #[cfg_attr(feature = "stubs", gen_stub_pyfunction(module = "qcs_api_client_common.configuration"))]
    #[pyfunction]
    async fn get_bearer_access_token(configuration: ClientConfiguration) -> PyResult<SecretAccessToken> {
        configuration.get_bearer_access_token().await.map_err(PyErr::from)
    }
}
320
py_function_sync_async! {
    /// Request and return an updated access token using these credentials.
    ///
    /// The returned token is cloned out of the session; the session itself is discarded.
    ///
    /// # Errors
    ///
    /// Raises a `TokenError` if there's a problem providing the token.
    #[cfg_attr(feature = "stubs", gen_stub_pyfunction(module = "qcs_api_client_common.configuration"))]
    #[pyfunction]
    async fn request_access_token(session: OAuthSession) -> PyResult<SecretAccessToken> {
        // NOTE(review): `session` is already owned; the `clone()` presumably just yields a
        // mutable temporary for `request_access_token`. Rebinding as `mut` would avoid the
        // copy — TODO confirm `py_function_sync_async!` accepts `mut` parameters.
        session.clone().request_access_token().await.cloned().map_err(PyErr::from)
    }
}
333
334impl_repr!(ClientConfiguration);
335
336#[cfg_attr(feature = "stubs", gen_stub_pymethods)]
337#[pymethods]
338impl ClientConfiguration {
339    #[new]
340    #[pyo3(signature = (
341            api_url = None, grpc_api_url = None, quilc_url = None, qvm_url = None,
342            oauth_session = None,
343            ))]
344    fn __new__(
345        api_url: Option<String>,
346        grpc_api_url: Option<String>,
347        quilc_url: Option<String>,
348        qvm_url: Option<String>,
349        oauth_session: Option<OAuthSession>,
350    ) -> Self {
351        let mut builder = ClientConfigurationBuilder::default();
352
353        if let Some(api_url) = api_url {
354            builder.api_url(api_url);
355        }
356
357        if let Some(grpc_api_url) = grpc_api_url {
358            builder.grpc_api_url(grpc_api_url);
359        }
360
361        if let Some(quilc_url) = quilc_url {
362            builder.quilc_url(quilc_url);
363        }
364
365        if let Some(qvm_url) = qvm_url {
366            builder.qvm_url(qvm_url);
367        }
368
369        builder.oauth_session(oauth_session);
370
371        builder
372            .build()
373            .expect("our builder is valid regardless of which URLs are set")
374    }
375
376    #[staticmethod]
377    #[pyo3(name = "load_default")]
378    fn py_load_default(_py: Python<'_>) -> Result<Self, LoadError> {
379        Self::load_default()
380    }
381
382    #[staticmethod]
383    #[pyo3(name = "load_default_with_login")]
384    fn py_load_default_with_login(py: Python<'_>) -> PyResult<Self> {
385        pyo3_async_runtimes::tokio::run(py, async move {
386            let cancel_token = cancel_token_with_ctrl_c();
387            Self::load_with_login(cancel_token, None)
388                .await
389                .map_err(Into::into)
390        })
391    }
392
393    #[staticmethod]
394    #[pyo3(name = "builder")]
395    fn py_builder() -> ClientConfigurationBuilder {
396        ClientConfigurationBuilder::default()
397    }
398
399    #[staticmethod]
400    #[pyo3(name = "load_profile")]
401    fn py_load_profile(_py: Python<'_>, profile_name: String) -> Result<Self, LoadError> {
402        Self::load_profile(profile_name)
403    }
404
405    /// Gets the `Bearer` access token, refreshing it if it is expired.
406    ///
407    /// # Errors
408    ///
409    /// Raises a `TokenError` if there's a problem providing the token.
410    #[pyo3(name = "get_bearer_access_token")]
411    fn py_get_bearer_access_token(&self, py: Python<'_>) -> PyResult<SecretAccessToken> {
412        py_get_bearer_access_token(py, self.clone())
413    }
414
415    #[pyo3(name = "get_bearer_access_token_async")]
416    fn py_get_bearer_access_token_async<'py>(
417        &self,
418        py: Python<'py>,
419    ) -> PyResult<Awaitable<'py, SecretAccessToken>> {
420        py_get_bearer_access_token_async(py, self.clone())
421    }
422
423    /// Get the configured [`OAuthSession`].
424    ///
425    /// # Errors
426    ///
427    /// Raises a `TokenError` if there is a problem fetching the tokens.
428    pub fn get_oauth_session(&self, py: Python<'_>) -> PyResult<OAuthSession> {
429        py_get_oauth_session(py, self.oauth_session.clone())
430    }
431
432    fn get_oauth_session_async<'py>(
433        &self,
434        py: Python<'py>,
435    ) -> PyResult<Awaitable<'py, OAuthSession>> {
436        py_get_oauth_session_async(py, self.oauth_session.clone())
437    }
438}
439
#[cfg_attr(feature = "stubs", gen_stub_pymethods)]
#[pymethods]
impl ClientConfigurationBuilder {
    /// Create a builder with nothing set, equivalent to `ClientConfigurationBuilder::default()`.
    #[new]
    fn __new__() -> Self {
        Self::default()
    }

    /// The [`OAuthSession`] to use to authenticate with the QCS API.
    ///
    /// When set to [`None`], the configuration will not manage an OAuth Session, and access to the
    /// QCS API will be limited to unauthenticated routes.
    #[setter]
    fn set_oauth_session(&mut self, oauth_session: Option<OAuthSession>) {
        // The outer `Some` presumably marks the builder field as explicitly set, even when the
        // session itself is `None`. The session is converted with `Into` before storage.
        self.oauth_session = Some(oauth_session.map(Into::into));
    }

    /// Build a [`ClientConfiguration`] from the values set on this builder.
    ///
    /// # Errors
    ///
    /// Raises a `ClientConfigurationBuilderError` if the builder cannot produce a configuration.
    #[pyo3(name = "build")]
    fn py_build(&self) -> Result<ClientConfiguration, ClientConfigurationBuilderError> {
        self.build()
    }
}
462
impl_repr!(AuthServer);

#[cfg_attr(feature = "stubs", gen_stub_pymethods)]
#[pymethods]
impl AuthServer {
    /// Create an `AuthServer` from a `client_id`, an `issuer`, and optional `scopes`, all
    /// forwarded unchanged to `AuthServer::new`.
    #[new]
    #[pyo3(signature = (client_id, issuer, scopes = None))]
    const fn __new__(client_id: String, issuer: String, scopes: Option<Vec<String>>) -> Self {
        Self::new(client_id, issuer, scopes)
    }

    /// Return `AuthServer::default()`.
    #[staticmethod]
    #[pyo3(name = "default")]
    fn py_default() -> Self {
        Self::default()
    }
}
480
481fn cancel_token_with_ctrl_c() -> CancellationToken {
482    let cancel_token = CancellationToken::new();
483    let cancel_token_ctrl_c = cancel_token.clone();
484    tokio::spawn(cancel_token.clone().run_until_cancelled_owned(async move {
485        match tokio::signal::ctrl_c().await {
486            Ok(()) => cancel_token_ctrl_c.cancel(),
487            Err(error) => eprintln!("Failed to register signal handler: {error}"),
488        }
489    }));
490    cancel_token
491}