#![allow(refining_impl_trait_internal, refining_impl_trait_reachable)]
use std::sync::Arc;
use axum::{
Router as AxumRouter,
extract::{Path, Request},
http::StatusCode,
response::{IntoResponse, Json},
routing::get,
};
use futures::stream;
use serde_json::json;
use tower_http::cors::{Any, CorsLayer};
use sunbeam_g2v::error::ServiceResult;
use sunbeam_g2v::health::{HealthRouter, PermissionHealthCheck};
use sunbeam_g2v::middleware::auth;
use sunbeam_g2v::middleware::auth::{
ApiKeyContext, AuthContext, AuthMiddlewareState, TenantId, api_key::TenantApiKeyStore,
api_key::hash_api_key, authorization::AuthorizationClient, authorization::AuthorizationConfig,
introspection::IntrospectionConfig, introspection::IntrospectionSessionClient,
permission::PermissionLayer, session::IdentityMapping, session::IdentityMappingStore,
};
use sunbeam_g2v::router::ServiceRouter;
use sunbeam_g2v::server::{ServerConfig, builder::ServerBuilder};
use sunbeam_g2v::{RequestContext, Router as ConnectRouter};
/// Eliza service definitions pulled in at compile time from `$OUT_DIR/_eliza.rs`.
///
/// The items imported from this module elsewhere in the file (`ElizaService`,
/// `SayRequest`, ...) come from that included file.
// NOTE(review): presumably `_eliza.rs` is written by a build script from the Eliza
// protobuf definition — confirm against build.rs.
mod eliza_proto {
    include!(concat!(env!("OUT_DIR"), "/_eliza.rs"));
}
use eliza_proto::connectrpc::eliza::v1::{
ElizaService, ElizaServiceExt, IntroduceRequest, IntroduceResponse, SayRequest, SayResponse,
};
/// Stateless demo implementation of the `ElizaService` trait.
struct ElizaServiceImpl;

/// Chooses Eliza's canned reply for an already lower-cased `sentence`.
///
/// Keywords are matched against whole words rather than arbitrary substrings, so
/// "this" does not count as the greeting "hi" and "goodbye" is recognised as a
/// farewell instead of being captured by the "good" keyword. The first matching
/// rule wins, in the order written below.
fn eliza_reply(sentence: &str) -> &'static str {
    // Tokenise once: every run of alphanumeric characters is a word.
    let words: Vec<&str> = sentence
        .split(|c: char| !c.is_alphanumeric())
        .filter(|word| !word.is_empty())
        .collect();
    let mentions = |keywords: &[&str]| {
        words
            .iter()
            .any(|word| keywords.iter().any(|keyword| keyword == word))
    };

    if mentions(&["hello", "hi"]) {
        "Hello! I'm Eliza, your digital therapist. How are you feeling today?"
    } else if mentions(&["feel", "feels", "feeling", "feelings"]) {
        "Tell me more about how you're feeling."
    } else if mentions(&["sad", "unhappy", "depressed"]) {
        "I'm sorry to hear that. What do you think is causing these feelings?"
    } else if mentions(&["happy", "good", "great"]) {
        "I'm glad to hear that! What's been making you feel this way?"
    } else if mentions(&["bye", "goodbye"]) {
        "Goodbye! Take care of yourself."
    } else if sentence.contains('?') {
        "That's an interesting question. What do you think the answer is?"
    } else if sentence.trim().is_empty() {
        // Whitespace-only input is treated the same as no input at all.
        "I'm listening. Please, go on."
    } else {
        "Please, tell me more."
    }
}

impl ElizaService for ElizaServiceImpl {
    /// Unary RPC: replies with a single canned sentence selected by keyword.
    ///
    /// The incoming sentence is lower-cased first, so matching is case-insensitive.
    /// This handler never returns an error.
    async fn say(
        &self,
        _ctx: RequestContext,
        request: connectrpc::ServiceRequest<'_, SayRequest>,
    ) -> connectrpc::ServiceResult<SayResponse> {
        let sentence = request.view().sentence.to_lowercase();
        let reply = eliza_reply(&sentence);
        Ok(connectrpc::Response::new(SayResponse {
            sentence: reply.to_string(),
            ..Default::default()
        }))
    }

    /// Server-streaming RPC: emits three fixed introduction sentences, the first of
    /// which greets the caller by the `name` taken from the request.
    async fn introduce(
        &self,
        _ctx: RequestContext,
        request: connectrpc::ServiceRequest<'_, IntroduceRequest>,
    ) -> connectrpc::ServiceResult<connectrpc::ServiceStream<IntroduceResponse>> {
        // Copy the name out of the request view before building the owned responses.
        let name = request.view().name.to_string();
        let responses = vec![
            Ok(IntroduceResponse {
                sentence: format!("Hi {name}! I'm Eliza, a digital therapist."),
                ..Default::default()
            }),
            Ok(IntroduceResponse {
                sentence: "I'm here to listen and help you explore your thoughts.".to_string(),
                ..Default::default()
            }),
            Ok(IntroduceResponse {
                sentence: "Feel free to tell me what's on your mind.".to_string(),
                ..Default::default()
            }),
        ];
        connectrpc::Response::stream_ok(stream::iter(responses))
    }
}
struct ExampleApiKeyStore;
#[async_trait::async_trait]
impl TenantApiKeyStore for ExampleApiKeyStore {
async fn get_by_hash(
&self,
hash: &str,
) -> Result<
sunbeam_g2v::middleware::auth::api_key::ApiKeyRow,
sunbeam_g2v::middleware::auth::error::AuthError,
> {
let expected = hash_api_key("dev-key");
if hash == expected {
Ok(sunbeam_g2v::middleware::auth::api_key::ApiKeyRow {
id: "key-1".to_string(),
tenant_id: "01H0J0000000000000000000R".to_string(),
key_hash: expected,
name: "dev".to_string(),
scopes: vec!["secrets:read".to_string()],
})
} else {
Err(sunbeam_g2v::middleware::auth::error::AuthError::InvalidApiKey)
}
}
}
/// Identity-mapping store for the example: it contains no mappings.
struct ExampleIdentityMappingStore;

#[async_trait::async_trait]
impl IdentityMappingStore for ExampleIdentityMappingStore {
    /// Always succeeds with `None`; both arguments are ignored.
    async fn get_identity_mapping(
        &self,
        _backend: &str,
        _identity_id: &str,
    ) -> Result<Option<IdentityMapping>, sunbeam_g2v::middleware::auth::error::AuthError> {
        let no_mapping: Option<IdentityMapping> = None;
        Ok(no_mapping)
    }
}
/// Handler for `GET /`: returns a small static JSON document naming the service.
async fn index_handler() -> impl IntoResponse {
    let body = json!({
        "service": "sunbeam-g2v simple example",
        "status": "ok"
    });
    Json(body)
}
/// Handler for `GET /whoami`: echoes the caller's identity as JSON.
///
/// Reads `TenantId`, `AuthContext` and `ApiKeyContext` from the request extensions
/// (presumably inserted by the auth middleware layered on this route — confirm).
/// Responds with `401 Unauthorized` unless both a tenant and an authenticated
/// context are present; `key_id` is `null` when no API-key context exists.
async fn whoami_handler(request: Request) -> impl IntoResponse {
    let extensions = request.extensions();

    // Both pieces are required, and the context must report itself as authenticated.
    let identity = extensions
        .get::<TenantId>()
        .cloned()
        .zip(extensions.get::<AuthContext>().cloned())
        .filter(|(_, ctx)| ctx.is_authenticated());

    let Some((tenant, ctx)) = identity else {
        let denied = (
            StatusCode::UNAUTHORIZED,
            Json(json!({ "error": "unauthenticated" })),
        );
        return denied.into_response();
    };

    let key_id = extensions
        .get::<ApiKeyContext>()
        .cloned()
        .map(|k| k.key_id);

    let body = json!({
        "tenant": tenant.0,
        "subject": ctx.subject,
        "scopes": ctx.scopes,
        "key_id": key_id,
    });
    Json(body).into_response()
}
/// Handler for `GET /g2v/secrets/{id}`: reports that access was granted.
///
/// Performs no checks of its own; it only echoes the path `id` together with the
/// tenant and subject found in the request extensions (each `null` when absent).
/// Access control is left to the layers attached to this route in `main`.
async fn secret_handler(Path(id): Path<String>, request: Request) -> impl IntoResponse {
    let extensions = request.extensions();
    let tenant_id = extensions.get::<TenantId>().cloned().map(|t| t.0);
    let subject = extensions
        .get::<AuthContext>()
        .cloned()
        .and_then(|c| c.subject);

    let body = json!({
        "tenant": tenant_id,
        "subject": subject,
        "secret_id": id,
        "message": "access granted"
    });
    Json(body)
}
#[tokio::main]
async fn main() -> ServiceResult<()> {
let auth_read_url =
std::env::var("AUTH_READ_URL").unwrap_or_else(|_| "http://localhost:4466".to_string());
let auth_write_url =
std::env::var("AUTH_WRITE_URL").unwrap_or_else(|_| "http://localhost:4467".to_string());
let introspection_url = std::env::var("HYDRA_INTROSPECTION_URL")
.unwrap_or_else(|_| "http://localhost:4445/oauth2/introspect".to_string());
let introspection_client_id =
std::env::var("HYDRA_CLIENT_ID").unwrap_or_else(|_| "example-client".to_string());
let introspection_client_secret =
std::env::var("HYDRA_CLIENT_SECRET").unwrap_or_else(|_| "example-secret".to_string());
let bind_addr: std::net::SocketAddr = std::env::var("BIND_ADDR")
.unwrap_or_else(|_| "127.0.0.1:8080".to_string())
.parse()
.unwrap_or_else(|_| "127.0.0.1:8080".parse().unwrap());
let auth_state = AuthMiddlewareState {
api_keys: Arc::new(ExampleApiKeyStore),
sessions: Arc::new(IntrospectionSessionClient::new(IntrospectionConfig {
url: introspection_url,
client_id: introspection_client_id,
client_secret: introspection_client_secret,
})),
identity_mappings: Arc::new(ExampleIdentityMappingStore),
};
let authorization_client = Arc::new(
AuthorizationClient::new(AuthorizationConfig {
read_url: auth_read_url,
write_url: auth_write_url,
})
.unwrap_or_else(|e| panic!("invalid authorization config: {e}")),
);
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any)
.expose_headers(Any);
let eliza = Arc::new(ElizaServiceImpl);
let connect_router: ConnectRouter = eliza.register(ConnectRouter::new());
let service_router = ServiceRouter::from_router(connect_router);
let public_routes = AxumRouter::new().route("/", get(index_handler));
let authed_routes = AxumRouter::new()
.route("/whoami", get(whoami_handler))
.layer(axum::middleware::from_fn_with_state(
auth_state.clone(),
auth::auth_middleware,
));
let permission_layer = PermissionLayer::new((*authorization_client).clone(), "G2vTest", "read")
.with_object_extractor(Arc::new(|req: &http::Request<()>| {
req.uri()
.path()
.rsplit('/')
.next()
.unwrap_or("")
.to_string()
}));
let authzd_routes = AxumRouter::new()
.route("/g2v/secrets/{id}", get(secret_handler))
.layer(permission_layer)
.layer(axum::middleware::from_fn_with_state(
auth_state.clone(),
auth::auth_middleware,
));
let app_routes = public_routes.merge(authed_routes).merge(authzd_routes);
let health = HealthRouter::new().with_check(Arc::new(PermissionHealthCheck::new(
(*authorization_client).clone(),
)));
let config = ServerConfig {
addr: bind_addr,
name: "simple-example".to_string(),
..ServerConfig::default()
};
let sample_key = "dev-key";
let sample_tenant = "01ARYZ6S41TSV4RRFFQ69G5FAV";
println!("listening on http://{}", config.addr);
println!();
println!(" sample API key: {sample_key}");
println!(" sample tenant: {sample_tenant}");
println!();
println!(" curl http://{}/", config.addr);
println!(
" curl -H 'X-Api-Key: {sample_key}' -H 'X-Tenant-Id: {sample_tenant}' http://{}/whoami",
config.addr
);
println!(
" curl -H 'X-Api-Key: {sample_key}' -H 'X-Tenant-Id: {sample_tenant}' http://{}/g2v/secrets/42",
config.addr
);
println!(" curl http://{}/health/live", config.addr);
println!(" curl http://{}/health/ready", config.addr);
println!();
println!(" # Connect-RPC Eliza (no auth):");
println!(
" curl -X POST -H 'Content-Type: application/json' -d '{{\"sentence\":\"hello\"}}' \\",
);
println!(
" http://{}/connectrpc.eliza.v1.ElizaService/Say",
config.addr
);
println!();
let shutdown = async {
let _ = tokio::signal::ctrl_c().await;
println!("shutdown signal received");
};
let server = ServerBuilder::new()
.with_router(service_router)
.with_config(config)
.with_health(health)
.with_routes(app_routes)
.build_axum()?;
let app = server.app().layer(cors);
let listener = tokio::net::TcpListener::bind(server.config().addr)
.await
.map_err(|e| sunbeam_g2v::error::ServiceError::Internal(format!("bind: {e}")))?;
axum::serve(listener, app)
.with_graceful_shutdown(shutdown)
.await
.map_err(|e| sunbeam_g2v::error::ServiceError::Internal(format!("axum::serve: {e}")))
}