gsm-core 0.4.21

Core types and platform abstractions for the Greentic messaging runtime.
Documentation
use std::sync::Arc;

use axum::{body::to_bytes, extract::State, response::IntoResponse};
use greentic_types::TenantCtx;
use gsm_core::platforms::webchat::{
    config::{Config, OAuthProviderConfig},
    conversation::memory_store,
    directline_client::{DirectLineError, MockDirectLineApi},
    http::{AppState, DirectLinePoster},
    oauth::{self, CLOSE_WINDOW_HTML, CallbackQuery, GreenticOauthClient, StartQuery},
    session::{MemorySessionStore, SharedSessionStore, WebchatSession, WebchatSessionStore},
};
use reqwest::Client;
use serde_json::Value;

#[path = "webchat_support.rs"]
mod support;

use support::{provider_with_secrets, signing_scope, tenant_ctx, tenant_scope};

/// `oauth::start` must answer with `307 Temporary Redirect` whose `Location`
/// is the configured issuer's `/authorize` URL, carrying the configured
/// `client_id` and echoing the caller-supplied `state`.
#[tokio::test]
async fn oauth_start_redirects_to_authorize() {
    let direct_line = Arc::new(MockDirectLineApi::default());
    let client = Client::builder().build().expect("build reqwest client");
    let store = Arc::new(MemorySessionStore::default());
    let sessions: SharedSessionStore = store.clone();
    // Seed a session for the conversation id passed to `oauth::start` below.
    store
        .upsert(WebchatSession::new(
            "conversation-123".to_string(),
            tenant_ctx("dev", "acme", None),
            "token-abc".to_string(),
        ))
        .await
        .expect("seed webchat session");

    // OAuth provider settings are read from tenant-scoped secrets.
    let oauth_scope = tenant_scope("dev", "acme", None);
    let secrets = [
        (
            &oauth_scope,
            "webchat_oauth",
            "issuer",
            "https://oauth.example.com",
        ),
        (&oauth_scope, "webchat_oauth", "client_id", "webchat-client"),
        (
            &oauth_scope,
            "webchat_oauth",
            "redirect_base",
            "https://webchat.example.com",
        ),
    ];
    let provider = provider_with_secrets(
        Config::with_base_url("https://directline.test/v3/directline"),
        signing_scope(),
        &secrets,
    );

    // `provider` is not needed after this point, so it is moved rather than cloned.
    let state = AppState::new(provider, direct_line, client)
        .with_sessions(sessions)
        .with_activity_poster(Arc::new(NoopPoster))
        .with_oauth_client(Arc::new(StaticOauthClient));

    let response = oauth::start(
        State(state),
        axum::extract::Query(StartQuery {
            conversation_id: "conversation-123".into(),
            state: Some("xyz".into()),
        }),
    )
    .await
    .expect("oauth::start should succeed")
    .into_response();

    assert_eq!(
        response.status(),
        axum::http::StatusCode::TEMPORARY_REDIRECT
    );
    let location = response
        .headers()
        .get(axum::http::header::LOCATION)
        .expect("redirect must carry a Location header");
    let url = reqwest::Url::parse(location.to_str().expect("Location is ASCII"))
        .expect("Location is a valid URL");
    assert_eq!(url.scheme(), "https");
    assert_eq!(url.host_str(), Some("oauth.example.com"));
    assert_eq!(url.path(), "/authorize");

    // Lookups use `get` instead of indexing. A missing query parameter then
    // shows up as an assertion diff (`None` vs `Some(..)`) rather than a bare
    // "key not found" panic.
    let params: std::collections::HashMap<String, String> =
        url.query_pairs().into_owned().collect();
    assert_eq!(
        params.get("client_id").map(String::as_str),
        Some("webchat-client")
    );
    assert_eq!(params.get("state").map(String::as_str), Some("xyz"));
}

/// `oauth::callback` with an authorization code must exchange it through the
/// configured OAuth client and respond `200 OK` with `CLOSE_WINDOW_HTML`.
///
/// With the `directline_standalone` feature, it additionally checks that a bot
/// activity ("You're signed in.") carrying the exchanged `oauth_token_handle`
/// was appended to the conversation.
#[tokio::test]
async fn oauth_callback_exchanges_code_and_posts_handle() {
    let direct_line = Arc::new(MockDirectLineApi::default());
    let client = Client::builder().build().expect("build reqwest client");
    let sessions: SharedSessionStore = Arc::new(MemorySessionStore::default());
    sessions
        .upsert(WebchatSession::new(
            "conversation-123".to_string(),
            tenant_ctx("dev", "acme", None),
            "token-abc".to_string(),
        ))
        .await
        .expect("seed webchat session");

    let oauth_scope = tenant_scope("dev", "acme", None);
    let provider = provider_with_secrets(
        Config::with_base_url("https://directline.test/v3/directline"),
        signing_scope(),
        &[
            (
                &oauth_scope,
                "webchat_oauth",
                "issuer",
                "https://oauth.example.com",
            ),
            (&oauth_scope, "webchat_oauth", "client_id", "webchat-client"),
            (
                &oauth_scope,
                "webchat_oauth",
                "redirect_base",
                "https://webchat.example.com",
            ),
        ],
    );

    let conversations = memory_store();
    // Under `directline_standalone`, create the conversation up front so the
    // activity posted by the callback can be read back below.
    #[cfg(feature = "directline_standalone")]
    {
        conversations
            .create("conversation-123", tenant_ctx("dev", "acme", None))
            .await
            .expect("create conversation");
    }

    let state = AppState::new(provider, direct_line, client)
        .with_sessions(Arc::clone(&sessions))
        .with_activity_poster(Arc::new(NoopPoster))
        .with_oauth_client(Arc::new(StaticOauthClient))
        .with_conversations(conversations.clone());

    let response = oauth::callback(
        State(state),
        axum::extract::Query(CallbackQuery {
            conversation_id: "conversation-123".into(),
            code: Some("auth-code-xyz".into()),
            state: None,
            error: None,
        }),
    )
    .await
    .expect("oauth::callback should succeed")
    .into_response();

    assert_eq!(response.status(), axum::http::StatusCode::OK);
    // No artificial size cap. A fixed limit such as 1024 would make this test
    // fail spuriously as soon as `CLOSE_WINDOW_HTML` outgrew it.
    let body = to_bytes(response.into_body(), usize::MAX)
        .await
        .expect("read response body");
    assert_eq!(body, CLOSE_WINDOW_HTML);

    #[cfg(feature = "directline_standalone")]
    {
        let page = conversations
            .activities("conversation-123", None)
            .await
            .expect("list conversation activities");
        let bot_activity = page
            .activities
            .back()
            .expect("callback should have posted an activity");
        assert_eq!(
            bot_activity.activity.text.as_deref(),
            Some("You're signed in.")
        );
        assert_eq!(
            bot_activity.activity.channel_data.as_ref().unwrap()["oauth_token_handle"],
            "token-handle-789"
        );
        // Advance the session watermark past the bot activity. Only success of
        // the call is checked here.
        let watermark = bot_activity.watermark;
        sessions
            .update_watermark("conversation-123", Some((watermark + 1).to_string()))
            .await
            .expect("update session watermark");
    }
}

/// A [`DirectLinePoster`] that never sends anything: every call returns
/// `Ok(())`, so tests can run handlers without a real Direct Line endpoint.
struct NoopPoster;

#[async_trait::async_trait]
impl DirectLinePoster for NoopPoster {
    /// Discards the endpoint, conversation, token and payload, then reports success.
    async fn post_activity(
        &self,
        _endpoint: &str,
        _conversation: &str,
        _token: &str,
        _payload: Value,
    ) -> Result<(), DirectLineError> {
        Result::Ok(())
    }
}

/// A [`GreenticOauthClient`] whose code exchange always succeeds with the fixed
/// handle `"token-handle-789"`, whatever tenant, provider config, authorization
/// code or redirect URI it receives.
#[derive(Clone)]
struct StaticOauthClient;

#[async_trait::async_trait]
impl GreenticOauthClient for StaticOauthClient {
    async fn exchange_code(
        &self,
        _tenant: &TenantCtx,
        _provider: &OAuthProviderConfig,
        _auth_code: &str,
        _redirect: &str,
    ) -> Result<String, anyhow::Error> {
        const HANDLE: &str = "token-handle-789";
        Ok(HANDLE.to_owned())
    }
}