use axum::{body::Body, extract::State, http::StatusCode, response::Json};
use toi::{GenerationRequest, Message, MessageRole};
use tracing::{debug, info, warn};
use utoipa::openapi::OpenApi;
use utoipa_axum::{router::OpenApiRouter, routes};
use crate::{
models::{
assistant::{GeneratedCommandExtraction, GeneratedRequest, parse_generated_response},
client::{ApiClientError, EmbeddingPromptTemplate, EmbeddingRequest, RerankRequest},
openapi::{NewSearchableOpenApiPathItem, OpenApiPathItem, SearchableOpenApiPathItem},
prompts::{CommandPrompt, HttpRequestPrompt, SimplePrompt, SummaryPrompt, SystemPrompt},
state::ToiState,
},
schema, utils,
};
/// Instruction half of the embedding prompt used when searching the indexed
/// API descriptions for the command extracted from a user's message.
const INSTRUCTION_PREFIX: &str = concat!(
    "Instruction: Given a user query, ",
    "retrieve RESTful API descriptions based on the command within the user's query"
);
/// Label placed in front of the extracted command in the embedding prompt.
const QUERY_PREFIX: &str = "Query: ";
/// Indexes the server's own OpenAPI operations for the assistant and builds its router.
///
/// Only `DELETE`, `GET`, `POST` and `PUT` operations on paths that contain no `{`
/// (i.e. no path parameters) are considered. An operation is indexed only when:
/// - its summary and/or description contain at least one non-blank line, and
/// - its `x-json-schema-params` / `x-json-schema-body` extensions are present exactly
///   when the operation declares parameters / a request body, respectively.
///
/// For operations that have a non-blank summary/description, those two extensions are
/// removed from `openapi` (even when the operation is then skipped).
///
/// Each indexed operation becomes one row in the `openapi` table. Each non-blank line of
/// its summary/description, with a leading `"- "` stripped, is embedded and stored in
/// `searchable_openapi`, with `parent_id` pointing back at that row. Both tables are
/// cleared first, so the index is rebuilt from scratch on every call.
///
/// # Errors
///
/// Returns an error if a pooled database connection cannot be obtained, a database
/// statement fails, or the embedding model call fails.
pub async fn assistant_router(
    openapi: &mut OpenApi,
    state: ToiState,
) -> Result<OpenApiRouter, Box<dyn std::error::Error>> {
    use diesel_async::RunQueryDsl;
    let mut conn = state.pool.get().await?;
    info!("preparing OpenAPI endpoints for automation");
    let mut new_searchable_openapi_path_items = vec![];
    // Clear the child table before its parent. `searchable_openapi.parent_id` holds ids
    // returned from inserts into `openapi`. Leftover rows from a previous run would
    // either block the parent delete (if there is a non-cascading foreign key) or dangle
    // and later be returned by the similarity search with no matching `openapi` row.
    diesel::delete(schema::searchable_openapi::table)
        .execute(&mut conn)
        .await?;
    diesel::delete(schema::openapi::table)
        .execute(&mut conn)
        .await?;
    for (path, item) in &mut openapi.paths.paths {
        // Paths with `{...}` segments are not indexed.
        if !path.contains('{') {
            for (method, op) in [
                ("DELETE", &mut item.delete),
                ("GET", &mut item.get),
                ("POST", &mut item.post),
                ("PUT", &mut item.put),
            ] {
                if let Some(op) = op {
                    let summary_and_description = [op.summary.clone(), op.description.clone()]
                        .into_iter()
                        .flatten()
                        .collect::<Vec<String>>()
                        .join("\n\n");
                    // Every non-blank line (minus a markdown "- " bullet) is embedded
                    // separately, so one operation can match several phrasings.
                    let descriptions: Vec<String> = summary_and_description
                        .lines()
                        .filter(|line| !line.trim().is_empty())
                        .map(|line| line.strip_prefix("- ").unwrap_or(line).to_string())
                        .collect();
                    if descriptions.is_empty() {
                        warn!(
                            "skipping uri={path} method={method} due to missing context from doc string"
                        );
                        continue;
                    }
                    // Take the JSON schemas out of the spec; they are stored alongside
                    // the operation instead.
                    let (params, body) = match op.extensions {
                        Some(ref mut extensions) => (
                            extensions.remove("x-json-schema-params"),
                            extensions.remove("x-json-schema-body"),
                        ),
                        None => (None, None),
                    };
                    // Declared parameters and the params schema must come as a pair.
                    match (&op.parameters, &params) {
                        (Some(_), Some(_)) | (None, None) => {}
                        (Some(_), None) => {
                            warn!(
                                "skipping uri={path} method={method} due to missing JSON schema parameters"
                            );
                            continue;
                        }
                        (None, Some(_)) => {
                            warn!(
                                "skipping uri={path} method={method} due to extra JSON schema parameters"
                            );
                            continue;
                        }
                    }
                    // Likewise for the request body and the body schema.
                    match (&op.request_body, &body) {
                        (Some(_), Some(_)) | (None, None) => {}
                        (Some(_), None) => {
                            warn!(
                                "skipping uri={path} method={method} due to missing JSON schema body"
                            );
                            continue;
                        }
                        (None, Some(_)) => {
                            warn!(
                                "skipping uri={path} method={method} due to extra JSON schema body"
                            );
                            continue;
                        }
                    }
                    info!("adding uri={path} method={method}");
                    let new_openapi_path_item = OpenApiPathItem {
                        path: path.to_string(),
                        method: method.to_string(),
                        description: summary_and_description,
                        params,
                        body,
                    };
                    let parent_id = diesel::insert_into(schema::openapi::table)
                        .values(&new_openapi_path_item)
                        .returning(schema::openapi::id)
                        .get_result(&mut conn)
                        .await?;
                    for description in descriptions {
                        debug!("processing line='{description}'");
                        let embedding_request = EmbeddingRequest {
                            input: description.clone(),
                        };
                        let embedding = state
                            .model_client
                            .embed(embedding_request)
                            .await
                            .map_err(|(_, err)| err)?;
                        new_searchable_openapi_path_items.push(NewSearchableOpenApiPathItem {
                            parent_id,
                            description,
                            embedding,
                        });
                    }
                }
            }
        }
    }
    // Skip the database round-trip when nothing was indexed.
    if !new_searchable_openapi_path_items.is_empty() {
        diesel::insert_into(schema::searchable_openapi::table)
            .values(&new_searchable_openapi_path_items)
            .execute(&mut conn)
            .await?;
    }
    // Return the connection to the pool before handing `state` to the router.
    drop(conn);
    let router = OpenApiRouter::new()
        .routes(routes!(assist))
        .with_state(state);
    Ok(router)
}
// Handler for the assistant endpoint. Plain `//` comments are used on purpose:
// utoipa turns `///` doc comments on a path handler into the operation's OpenAPI
// summary/description, and this internal walkthrough should not be published.
//
// Flow:
// 1. `CommandPrompt` asks the model to extract a command from the conversation.
// 2. If one is found, it is embedded and the 16 `searchable_openapi` rows nearest by
//    cosine distance are reranked against it.
// 3. If the best-scoring API reaches `server_config.similarity_threshold`,
//    `HttpRequestPrompt` has the model generate a request for it. That request is sent
//    to this server on 127.0.0.1. The generated request and the raw response text are
//    appended to the conversation, and a `SummaryPrompt` reply is streamed.
// 4. In every other case a `SimplePrompt` reply is streamed: no messages, no command,
//    nothing indexed, no rerank results, or a score below the threshold.
//
// Errors are returned as `(StatusCode, String)`:
// - model-client and proxied-HTTP failures are propagated as produced by those clients;
// - database query failures become 500;
// - a rerank result whose index falls outside the candidate list becomes 422.
#[utoipa::path(
    post,
    path = "",
    request_body = GenerationRequest,
    responses(
        (status = 200, description = "Successfully got a response"),
        (status = 400, description = "Default JSON elements configured by the user are invalid"),
        (status = 422, description = "Error when parsing a response from a model API"),
        (status = 500, description = "Error when querying the API search database"),
        (status = 502, description = "Error when forwarding request to model APIs")
    )
)]
#[axum::debug_handler]
async fn assist(
    State(state): State<ToiState>,
    Json(mut request): Json<GenerationRequest>,
) -> Result<Body, (StatusCode, String)> {
    let streaming_generation_request = if let Some(message) = request.messages.last() {
        debug!(">> {}", message.content);
        let system_prompt = CommandPrompt {};
        let generation_request = GenerationRequest::builder()
            .messages(system_prompt.to_messages(&request.messages))
            .response_format(system_prompt.into_response_format())
            .build();
        debug!("preparing extraction request");
        let generated_command_extraction = state.model_client.generate(generation_request).await?;
        debug!("parsing extraction request");
        let generated_command_extraction =
            parse_generated_response::<GeneratedCommandExtraction>(&generated_command_extraction)?;
        debug!("extraction={:?}", generated_command_extraction);
        let GeneratedCommandExtraction { command, .. } = generated_command_extraction;
        if let Some(command) = command {
            debug!("embedding message for API search");
            let input = EmbeddingPromptTemplate::builder()
                .instruction_prefix(INSTRUCTION_PREFIX.to_string())
                .query_prefix(QUERY_PREFIX.to_string())
                .build()
                .apply(&command);
            let embedding_request = EmbeddingRequest { input };
            let embedding = state.model_client.embed(embedding_request).await?;
            let mut conn = state.pool.get().await.map_err(utils::internal_error)?;
            let items: Vec<SearchableOpenApiPathItem> = {
                use diesel::{QueryDsl, SelectableHelper};
                use diesel_async::RunQueryDsl;
                use pgvector::VectorExpressionMethods;
                schema::searchable_openapi::table
                    .select(SearchableOpenApiPathItem::as_select())
                    .order(schema::searchable_openapi::embedding.cosine_distance(embedding))
                    .limit(16)
                    .load(&mut conn)
                    .await
                    .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?
            };
            // `Some(item)` only when an API is relevant enough to call.
            let relevant_item = if items.is_empty() {
                // Nothing has been indexed, so there is nothing to rerank.
                warn!("no searchable API items found");
                None
            } else {
                debug!("reranking API search results for relevance");
                let (ids, documents): (Vec<i32>, Vec<String>) = items
                    .into_iter()
                    .map(|item| (item.parent_id, item.description))
                    .unzip();
                let rerank_request = RerankRequest {
                    query: command,
                    documents,
                };
                let rerank_response = state.model_client.rerank(rerank_request).await?;
                // Pick the top score explicitly rather than relying on the reranker's
                // ordering; on ties the earliest result wins.
                let most_relevant_result = rerank_response.results.iter().reduce(|best, result| {
                    if result.relevance_score > best.relevance_score {
                        result
                    } else {
                        best
                    }
                });
                match most_relevant_result {
                    None => {
                        warn!("reranker returned no results");
                        None
                    }
                    Some(most_relevant_result) => {
                        // The index comes from an external service; bounds-check it
                        // instead of panicking on a malformed response.
                        let Some(&parent_id) = ids.get(most_relevant_result.index) else {
                            return Err((
                                StatusCode::UNPROCESSABLE_ENTITY,
                                format!(
                                    "reranker returned index {} for {} documents",
                                    most_relevant_result.index,
                                    ids.len()
                                ),
                            ));
                        };
                        let item: OpenApiPathItem = {
                            use diesel::{ExpressionMethods, QueryDsl, SelectableHelper};
                            use diesel_async::RunQueryDsl;
                            schema::openapi::table
                                .select(OpenApiPathItem::as_select())
                                .filter(schema::openapi::id.eq(parent_id))
                                .first(&mut conn)
                                .await
                                .map_err(|err| {
                                    (StatusCode::INTERNAL_SERVER_ERROR, err.to_string())
                                })?
                        };
                        info!(
                            "most relevant API (uri={} method={}) scored at {:.3}",
                            item.path, item.method, most_relevant_result.relevance_score
                        );
                        if most_relevant_result.relevance_score
                            >= state.server_config.similarity_threshold
                        {
                            debug!("API passes similarity threshold");
                            Some(item)
                        } else {
                            debug!("no APIs pass similarity threshold");
                            None
                        }
                    }
                }
            };
            // Only model and HTTP calls remain; return the pooled connection so it is not
            // held for their duration.
            drop(conn);
            if let Some(item) = relevant_item {
                let OpenApiPathItem {
                    path,
                    method,
                    description,
                    params,
                    body,
                } = item;
                let system_prompt = HttpRequestPrompt {
                    path,
                    method,
                    params,
                    body,
                };
                let generation_request = GenerationRequest::builder()
                    .messages(system_prompt.to_messages(&request.messages))
                    .response_format(system_prompt.into_response_format())
                    .build();
                debug!("preparing proxy API request");
                let generated_request = state.model_client.generate(generation_request).await?;
                debug!("parsing proxy API request");
                let generated_request =
                    parse_generated_response::<GeneratedRequest>(&generated_request)?;
                debug!("proxy API request={:?}", generated_request);
                // The proxied call targets this server itself via loopback.
                let server_addr = format!("127.0.0.1:{}", &state.server_config.bind_addr.port());
                let http_request =
                    generated_request.to_http_request(&state.api_client, &server_addr);
                let assistant_message = generated_request.into_assistant_message();
                request.messages.push(assistant_message);
                debug!("sending proxy API request");
                let response = state
                    .api_client
                    .execute(http_request)
                    .await
                    .map_err(|err| ApiClientError::ApiConnection.into_response(&err))?;
                debug!("receiving proxy API response");
                // A body-read failure is fed to the model as text rather than failing
                // the whole request.
                let content = response
                    .text()
                    .await
                    .unwrap_or_else(|err| format!("{err:?}"));
                request.messages.push(Message {
                    role: MessageRole::User,
                    content,
                });
                debug!("summarizing API response");
                SummaryPrompt { description }.to_streaming_generation_request(&request.messages)
            } else {
                SimplePrompt {}.to_streaming_generation_request(&request.messages)
            }
        } else {
            warn!("no command found in request");
            SimplePrompt {}.to_streaming_generation_request(&request.messages)
        }
    } else {
        warn!("no message found in request");
        SimplePrompt {}.to_streaming_generation_request(&request.messages)
    };
    debug!("beginning response stream");
    let stream = state
        .model_client
        .generate_stream(streaming_generation_request)
        .await?;
    Ok(stream)
}