llmvm_core_lib/service.rs

use std::fs::create_dir;

use async_stream::stream;
use futures::StreamExt;
use llmvm_protocol::{
    error::ProtocolErrorType, service::BackendResponse, Core, GenerationRequest,
    GenerationResponse, Message, NotificationStream, ProtocolError, ThreadInfo,
};
use llmvm_util::{get_file_path, DirType};
use tracing::debug;

use crate::{
    error::CoreError,
    threads::{
        get_thread_infos, get_thread_messages, maybe_save_thread_messages_and_get_thread_id,
    },
    LLMVMCore, PROJECT_DIR_NAME,
};

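/// `Core` service implementation for [`LLMVMCore`]: forwards generation
/// requests to the configured backend and manages message threads.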
#[llmvm_protocol::async_trait]
impl Core for LLMVMCore {
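    /// Generate a single, non-streamed completion for the given request,
    /// returning the full response text along with the thread id if the
    /// exchange was saved to a thread.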
    async fn generate(
        &self,
        request: GenerationRequest,
    ) -> std::result::Result<GenerationResponse, ProtocolError> {
        async {
            let (backend_request, model_description, thread_messages_to_save) =
                self.prepare_for_generate(&request).await?;

            let response = self
                .send_generate_request(backend_request, &model_description)
                .await?;

            debug!("Response: {}", response.response);

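            // Optionally persist the exchange as thread messages and obtain
            // the resulting thread id.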
            let thread_id = maybe_save_thread_messages_and_get_thread_id(
                &request,
                response.response.clone(),
                thread_messages_to_save,
            )
            .await?;

            Ok(GenerationResponse {
                response: response.response,
                thread_id,
            })
        }
        .await
        .map_err(|e: CoreError| e.into())
    }

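    /// Generate a streamed completion for the given request.
    ///
    /// Yields incremental [`GenerationResponse`] chunks as the backend produces
    /// them; after the backend stream ends, a final response with an empty
    /// `response` and the saved thread id is yielded if the thread was saved.
    ///
    /// A minimal consumption sketch (assuming `core` is an `LLMVMCore` and
    /// `request` is an already-built `GenerationRequest`):
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut stream = core.generate_stream(request).await?;
    /// while let Some(result) = stream.next().await {
    ///     let chunk = result?;
    ///     print!("{}", chunk.response);
    ///     if let Some(id) = chunk.thread_id {
    ///         eprintln!("\nsaved to thread {}", id);
    ///     }
    /// }
    /// ```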
    async fn generate_stream(
        &self,
        request: GenerationRequest,
    ) -> std::result::Result<NotificationStream<GenerationResponse>, ProtocolError> {
        async {
            let (backend_request, model_description, thread_messages_to_save) =
                self.prepare_for_generate(&request).await?;

            let mut stream = self
                .send_generate_request_for_stream(backend_request, &model_description)
                .await?;

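            // Bridge the backend stream into a client-facing notification
            // stream, accumulating the full text so the thread can be saved
            // once streaming finishes.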
            Ok(stream! {
                let mut full_response = String::new();
                while let Some(result) = stream.next().await {
                    match result {
                        Ok(response) => match response {
                            BackendResponse::GenerationStream(response) => {
                                full_response.push_str(&response.response);
                                yield Ok(GenerationResponse {
                                    response: response.response,
                                    thread_id: None
                                });
                            }
                            _ => yield Err(CoreError::UnexpectedServiceResponse.into())
                        },
                        Err(e) => {
                            yield Err(e);
                        }
                    }
                }
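                // The backend stream has ended: optionally save the thread and
                // emit a final, empty-text response carrying the thread id.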
                if let Ok(thread_id) = maybe_save_thread_messages_and_get_thread_id(&request, full_response, thread_messages_to_save).await {
                    yield Ok(GenerationResponse { response: String::new(), thread_id });
                }
            }.boxed())
        }
        .await
        .map_err(|e: CoreError| e.into())
    }

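    /// Return info for the most recent thread, taken as the first entry
    /// returned by `get_thread_infos`, or `None` if no threads exist.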
    async fn get_last_thread_info(&self) -> std::result::Result<Option<ThreadInfo>, ProtocolError> {
        async { Ok(get_thread_infos().await?.into_iter().next()) }
            .await
            .map_err(|e: CoreError| e.into())
    }

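    /// Return info for all known threads.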
    async fn get_all_thread_infos(&self) -> std::result::Result<Vec<ThreadInfo>, ProtocolError> {
        get_thread_infos().await.map_err(|e| e.into())
    }

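    /// Return the messages stored in the thread with the given id.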
    async fn get_thread_messages(
        &self,
        id: String,
    ) -> std::result::Result<Vec<Message>, ProtocolError> {
        get_thread_messages(&id).await.map_err(|e| e.into())
    }

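    /// Create the llmvm project directory in the current working directory
    /// and initialize the standard llmvm directories (prompts, presets,
    /// threads, logs, config, weights).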
    fn init_project(&self) -> std::result::Result<(), ProtocolError> {
        create_dir(PROJECT_DIR_NAME).map_err(|error| ProtocolError {
            error_type: ProtocolErrorType::Internal,
            error: Box::new(error),
        })?;
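        // Touch each standard directory; the `true` flag is assumed to ask
        // get_file_path to create the directory if it does not exist.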
        get_file_path(DirType::Prompts, "", true);
        get_file_path(DirType::Presets, "", true);
        get_file_path(DirType::Threads, "", true);
        get_file_path(DirType::Logs, "", true);
        get_file_path(DirType::Config, "", true);
        get_file_path(DirType::Weights, "", true);
        Ok(())
    }
}