use std::convert::Infallible;
use std::marker::PhantomData;
use std::str::FromStr;
use async_graphql::{
http::GraphiQLSource, Context, EmptySubscription, Object, Schema, SimpleObject,
};
use async_graphql_warp::{GraphQLBadRequest, GraphQLResponse};
use dotenvy::dotenv;
use http::StatusCode;
use myself::backend::core::AgentBackend;
use myself::backend::openai::OpenAIBackend;
use myself::database::memory::MemoryEngine;
use myself::sdk::agent::{Agent, AgentBuilder};
use myself::sdk::interaction::{Interaction as MyselfInteraction, InteractionState};
use std::env::var;
use uuid::Uuid;
use warp::{http::Response as HttpResponse, Filter, Rejection};
/// Root object for GraphQL queries, generic over the agent backend type.
struct QueryRoot<Backend: AgentBackend + Sized + Default + Clone> {
    /// Marker that carries the `Backend` type parameter; no backend value is stored.
    _backend: PhantomData<Backend>,
}
/// Root object for GraphQL mutations, generic over the agent backend type.
struct MutationRoot<Backend: AgentBackend + Sized + Default + Clone> {
    /// Marker that carries the `Backend` type parameter; no backend value is stored.
    _backend: PhantomData<Backend>,
}
/// GraphQL-facing view of an agent interaction, built by `Interaction::parse`.
#[derive(SimpleObject)]
struct Interaction {
/// Interaction identifier rendered as a string.
id: String,
/// User name copied from the agent's interaction record.
user_name: String,
/// Short-term memory copied from the agent's interaction record.
short_term_memory: String,
}
/// Result of sending a message to the agent: the reply plus the interaction it belongs to.
#[derive(SimpleObject)]
struct InteractionResponse {
/// Content of the agent's reply message.
response: String,
/// Interaction as read back from the agent after the message was processed.
interaction: Interaction,
}
impl Interaction {
fn parse<Backend, State>(db_interaction: &MyselfInteraction<Backend, State>) -> Self
where
Backend: AgentBackend + Sized + Default + Clone,
State: InteractionState,
{
Self {
id: db_interaction.id.to_string(),
user_name: db_interaction.user_name.to_owned(),
short_term_memory: db_interaction.short_term_memory.to_owned(),
}
}
}
#[Object]
impl<Backend> QueryRoot<Backend>
where
Backend: AgentBackend + Sized + Default + Clone + Send + Sync + 'static,
{
async fn interactions<'a>(&self, ctx: &Context<'a>) -> Vec<Interaction>
where {
let mut agent = ctx.data::<Agent<Backend>>().unwrap().to_owned();
agent
.get_all_interactions()
.await
.iter()
.map(Interaction::parse)
.collect()
}
async fn interaction<'a>(&self, ctx: &Context<'a>, id: String) -> Option<Interaction> {
let mut agent = ctx.data::<Agent<Backend>>().unwrap().to_owned();
agent
.get_interaction(Uuid::from_str(id.as_str()).unwrap())
.await
.map(|i| Interaction::parse(&i))
}
}
#[Object]
impl<Backend> MutationRoot<Backend>
where
Backend: AgentBackend + Sized + Default + Clone + Send + Sync + 'static,
{
async fn new_interaction<'a>(
&self,
ctx: &Context<'a>,
user_name: String,
constitution: String,
memory_size: usize,
) -> Interaction {
let mut agent = ctx.data::<Agent<Backend>>().unwrap().to_owned();
let interaction = agent
.init_interaction(user_name, constitution, memory_size)
.await;
Interaction::parse(&interaction)
}
async fn interact_default<'a>(
&self,
ctx: &Context<'a>,
message: String,
) -> InteractionResponse {
let mut agent = ctx.data::<Agent<Backend>>().unwrap().to_owned();
let (_, response) = agent.interact_default(&message).await.unwrap();
InteractionResponse {
response: response.content,
interaction: Interaction::parse(&agent.get_default_interaction().await),
}
}
async fn interact<'a>(
&self,
ctx: &Context<'a>,
id: String,
message: String,
) -> InteractionResponse {
let mut agent = ctx.data::<Agent<Backend>>().unwrap().to_owned();
let uuid = Uuid::from_str(id.as_str()).unwrap();
let (_, response) = agent.interact(uuid, &message).await.unwrap();
InteractionResponse {
response: response.content,
interaction: Interaction::parse(&agent.get_interaction(uuid).await.unwrap()),
}
}
async fn update_constitution<'a>(
&self,
ctx: &Context<'a>,
id: String,
constitution: String,
) -> Interaction {
let mut agent = ctx.data::<Agent<Backend>>().unwrap().to_owned();
let interaction = agent
.update_long_term_memory(Uuid::from_str(id.as_str()).unwrap(), constitution)
.await;
Interaction::parse(&interaction)
}
async fn forget_memory<'a>(&self, ctx: &Context<'a>, id: String) -> Interaction {
let mut agent = ctx.data::<Agent<Backend>>().unwrap().to_owned();
let interaction = agent
.forgot_short_term_memory(Uuid::from_str(id.as_str()).unwrap())
.await;
Interaction::parse(&interaction.unwrap())
}
}
/// Builds the agent and GraphQL schema, then serves GraphiQL and the GraphQL
/// endpoint on `127.0.0.1:8000`.
///
/// # Panics
///
/// Panics at startup if `OPENAI_API_KEY` or `DATABASE_URL` cannot be read from
/// the environment.
#[tokio::main]
async fn main() {
    // The result is discarded on purpose: a failure to load `.env` is not fatal here.
    dotenv().ok();

    // Name the missing variable in the panic message so a misconfigured
    // deployment is easy to diagnose.
    let llm_engine = OpenAIBackend::new(
        var("OPENAI_API_KEY").expect("environment variable OPENAI_API_KEY must be set"),
    );
    let memory_engine = MemoryEngine::new(
        var("DATABASE_URL").expect("environment variable DATABASE_URL must be set"),
    )
    .await;

    let agent = AgentBuilder::new()
        .name("AI".to_string())
        .build(llm_engine, memory_engine)
        .await;

    // The agent is stored as schema data; resolvers fetch it from the request context.
    let schema = Schema::build(
        QueryRoot {
            _backend: PhantomData,
        },
        MutationRoot {
            _backend: PhantomData,
        },
        EmptySubscription,
    )
    .data(agent)
    .finish();

    // The closure's parameter annotation fixes the roots' `Backend` parameter to `OpenAIBackend`.
    let graphql_post = async_graphql_warp::graphql(schema).and_then(
        |(schema, request): (
            Schema<QueryRoot<OpenAIBackend>, MutationRoot<OpenAIBackend>, EmptySubscription>,
            async_graphql::Request,
        )| async move {
            Ok::<_, Infallible>(GraphQLResponse::from(schema.execute(request).await))
        },
    );

    // GET on the root path serves the GraphiQL page, pointed at the same root endpoint.
    let graphiql = warp::path::end().and(warp::get()).map(|| {
        HttpResponse::builder()
            .header("content-type", "text/html")
            .body(GraphiQLSource::build().endpoint("/").finish())
    });

    // Malformed GraphQL requests become 400; any other rejection becomes 500.
    let routes = graphiql
        .or(graphql_post)
        .recover(|err: Rejection| async move {
            if let Some(GraphQLBadRequest(err)) = err.find() {
                return Ok::<_, Infallible>(warp::reply::with_status(
                    err.to_string(),
                    StatusCode::BAD_REQUEST,
                ));
            }
            Ok(warp::reply::with_status(
                "INTERNAL_SERVER_ERROR".to_string(),
                StatusCode::INTERNAL_SERVER_ERROR,
            ))
        });

    println!("GraphiQL IDE: http://localhost:8000");
    warp::serve(routes).run(([127, 0, 0, 1], 8000)).await;
}